"""Early stopping""" # pylint: disable= no-member, arguments-differ, invalid-name import datetime import torch __all__ = ['EarlyStopping'] # pylint: disable=C0103 class EarlyStopping(object): """Early stop tracker Save model checkpoint when observing a performance improvement on the validation set and early stop if improvement has not been observed for a particular number of epochs. Parameters ---------- mode : str * 'higher': Higher metric suggests a better model * 'lower': Lower metric suggests a better model patience : int The early stopping will happen if we do not observe performance improvement for ``patience`` consecutive epochs. filename : str or None Filename for storing the model checkpoint. If not specified, we will automatically generate a file starting with ``early_stop`` based on the current time. Examples -------- Below gives a demo for a fake training process. >>> import torch >>> import torch.nn as nn >>> from torch.nn import MSELoss >>> from torch.optim import Adam >>> from dgllife.utils import EarlyStopping >>> model = nn.Linear(1, 1) >>> criterion = MSELoss() >>> # For MSE, the lower, the better >>> stopper = EarlyStopping(mode='lower', filename='test.pth') >>> optimizer = Adam(params=model.parameters(), lr=1e-3) >>> for epoch in range(1000): >>> x = torch.randn(1, 1) # Fake input >>> y = torch.randn(1, 1) # Fake label >>> pred = model(x) >>> loss = criterion(y, pred) >>> optimizer.zero_grad() >>> loss.backward() >>> optimizer.step() >>> early_stop = stopper.step(loss.detach().data, model) >>> if early_stop: >>> break >>> # Load the final parameters saved by the model >>> stopper.load_checkpoint(model) """ def __init__(self, mode='higher', patience=10, filename=None): if filename is None: dt = datetime.datetime.now() filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format( dt.date(), dt.hour, dt.minute, dt.second) assert mode in ['higher', 'lower'] self.mode = mode if self.mode == 'higher': self._check = self._check_higher else: self._check = self._check_lower self.patience = patience self.counter = 0 self.filename = filename self.best_score = None self.early_stop = False def _check_higher(self, score, prev_best_score): """Check if the new score is higher than the previous best score. Parameters ---------- score : float New score. prev_best_score : float Previous best score. Returns ------- bool Whether the new score is higher than the previous best score. """ return score > prev_best_score def _check_lower(self, score, prev_best_score): """Check if the new score is lower than the previous best score. Parameters ---------- score : float New score. prev_best_score : float Previous best score. Returns ------- bool Whether the new score is lower than the previous best score. """ return score < prev_best_score def step(self, score, model): """Update based on a new score. The new score is typically model performance on the validation set for a new epoch. Parameters ---------- score : float New score. model : nn.Module Model instance. Returns ------- bool Whether an early stop should be performed. """ if self.best_score is None: self.best_score = score self.save_checkpoint(model) elif self._check(score, self.best_score): self.best_score = score self.save_checkpoint(model) self.counter = 0 else: self.counter += 1 print( f'EarlyStopping counter: {self.counter} out of {self.patience}') if self.counter >= self.patience: self.early_stop = True return self.early_stop def save_checkpoint(self, model): '''Saves model when the metric on the validation set gets improved. 
Parameters ---------- model : nn.Module Model instance. ''' torch.save({'model_state_dict': model.state_dict()}, self.filename) def load_checkpoint(self, model): '''Load the latest checkpoint Parameters ---------- model : nn.Module Model instance. ''' model.load_state_dict(torch.load(self.filename)['model_state_dict'])
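
if __name__ == '__main__':
    # A minimal runnable sketch of the workflow from the class docstring,
    # assuming only ``torch`` is available. The model, data, and
    # hyperparameters below are placeholders for illustration, not
    # recommendations.
    import torch.nn as nn
    from torch.optim import Adam

    model = nn.Linear(1, 1)
    criterion = nn.MSELoss()
    # For MSE, a lower score is better, so track the minimum.
    stopper = EarlyStopping(mode='lower', patience=5, filename='test.pth')
    optimizer = Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        x = torch.randn(4, 1)  # Fake input batch
        y = torch.randn(4, 1)  # Fake labels
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # step() saves a checkpoint on improvement and returns True once
        # ``patience`` epochs pass without one.
        if stopper.step(loss.item(), model):
            print(f'Early stopping at epoch {epoch}')
            break

    # Restore the parameters from the best checkpoint saved by the stopper.
    stopper.load_checkpoint(model)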