import torch


class EarlyStopping:
    """Stop training once a monitored score stops improving.

    Higher scores are treated as better (the value passed to ``step`` is
    named ``acc``). Each time a score matches or beats the best seen so far,
    the model's ``state_dict`` is written to ``checkpoint_path`` and the
    patience counter resets. After ``patience`` consecutive non-improving
    steps, ``early_stop`` becomes ``True`` and stays ``True``.

    Args:
        patience: Number of consecutive non-improving steps tolerated before
            stopping. Any value of 0 or below (including the default of -1)
            stops on the first non-improving step.
        checkpoint_path: File the best model's ``state_dict`` is saved to.
    """

    def __init__(self, patience: int = -1, checkpoint_path: str = 'checkpoint.pt'):
        self.patience = patience
        self.checkpoint_path = checkpoint_path
        self.counter = 0  # consecutive steps without improvement
        self.best_score = None  # best score seen so far; None before the first step
        self.early_stop = False  # latched to True once patience is exhausted

    def step(self, acc, model: torch.nn.Module) -> bool:
        """Record one evaluation result and checkpoint the model on improvement.

        Args:
            acc: Score for this evaluation; higher is better.
            model: Model whose ``state_dict()`` is saved when the score improves.

        Returns:
            ``True`` once the patience budget has been exhausted. A later
            improvement resets the counter but does not clear this flag.
        """
        score = acc
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif not score >= self.best_score:
            # Written as ``not >=`` instead of ``<`` so that a NaN score counts
            # as "no improvement". With ``<`` a NaN would replace best_score,
            # after which no later comparison could ever register as worse.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Ties count as improvements: refresh the checkpoint, reset counter.
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model: torch.nn.Module) -> None:
        """Save ``model.state_dict()`` to ``checkpoint_path`` (called when the score improves)."""
        torch.save(model.state_dict(), self.checkpoint_path)

    def load_checkpoint(self, model: torch.nn.Module, map_location=None) -> None:
        """Load the saved ``state_dict`` from ``checkpoint_path`` into ``model``.

        Args:
            model: Model to restore in place.
            map_location: Forwarded to ``torch.load``. For example, pass
                ``"cpu"`` to restore a checkpoint saved from GPU tensors on a
                CPU-only machine. ``None`` keeps the original behavior.
        """
        model.load_state_dict(torch.load(self.checkpoint_path, map_location=map_location))