"""Early stopping""" import datetime import torch __all__ = ['EarlyStopping'] class EarlyStopping(object): """Early stop tracker Save model checkpoint when observing a performance improvement on the validation set and early stop if improvement has not been observed for a particular number of epochs. Parameters ---------- mode : str * 'higher': Higher metric suggests a better model * 'lower': Lower metric suggests a better model patience : int The early stopping will happen if we do not observe performance improvement for ``patience`` consecutive epochs. filename : str or None Filename for storing the model checkpoint """ def __init__(self, mode='higher', patience=10, filename=None): if filename is None: dt = datetime.datetime.now() filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format( dt.date(), dt.hour, dt.minute, dt.second) assert mode in ['higher', 'lower'] self.mode = mode if self.mode == 'higher': self._check = self._check_higher else: self._check = self._check_lower self.patience = patience self.counter = 0 self.filename = filename self.best_score = None self.early_stop = False def _check_higher(self, score, prev_best_score): """Check if the new score is higher than the previous best score. Parameters ---------- score : float New score. prev_best_score : float Previous best score. Returns ------- bool Whether the new score is higher than the previous best score. """ return (score > prev_best_score) def _check_lower(self, score, prev_best_score): """Check if the new score is lower than the previous best score. Parameters ---------- score : float New score. prev_best_score : float Previous best score. Returns ------- bool Whether the new score is lower than the previous best score. """ return (score < prev_best_score) def step(self, score, model): """Update based on a new score. The new score is typically model performance on the validation set for a new epoch. Parameters ---------- score : float New score. model : nn.Module Model instance. Returns ------- bool Whether an early stop should be performed. """ if self.best_score is None: self.best_score = score self.save_checkpoint(model) elif self._check(score, self.best_score): self.best_score = score self.save_checkpoint(model) self.counter = 0 else: self.counter += 1 print( f'EarlyStopping counter: {self.counter} out of {self.patience}') if self.counter >= self.patience: self.early_stop = True return self.early_stop def save_checkpoint(self, model): '''Saves model when the metric on the validation set gets improved. Parameters ---------- model : nn.Module Model instance. ''' torch.save({'model_state_dict': model.state_dict()}, self.filename) def load_checkpoint(self, model): '''Load the latest checkpoint Parameters ---------- model : nn.Module Model instance. ''' model.load_state_dict(torch.load(self.filename)['model_state_dict'])