import torch
from torch.optim.optimizer import Optimizer


class LinearDecayLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear learning rate scheduler with warm up.

    The learning rate ramps up linearly from 0 to the base LR over
    ``warmup_updates`` steps, then decays linearly back to 0 by
    ``max_updates`` steps.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_updates: int,
        max_updates: int,
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        self.warmup_updates = warmup_updates
        self.max_updates = max_updates
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        if self._step_count <= self.warmup_updates:
            # Warmup phase: scale each base LR by the fraction of warmup completed.
            return [
                self._step_count / self.warmup_updates * base_lr
                for base_lr in self.base_lrs
            ]
        elif self._step_count >= self.max_updates:
            # Past the schedule horizon: hold the learning rate at zero.
            return [0.0 for _ in self.base_lrs]
        else:
            # Decay phase: scale each base LR by the fraction of decay steps remaining.
            pct_remaining = (self.max_updates - self._step_count) / (
                self.max_updates - self.warmup_updates
            )
            return [base_lr * pct_remaining for base_lr in self.base_lrs]
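

# A minimal usage sketch, assuming a toy nn.Linear model, SGD, and illustrative
# hyperparameters (warmup_updates=100, max_updates=1000); none of these values
# come from the scheduler itself.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 2)  # hypothetical toy model for illustration
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = LinearDecayLRScheduler(optimizer, warmup_updates=100, max_updates=1000)

    for step in range(1000):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).sum()  # dummy forward pass and loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the scheduler once per parameter update
        if step % 250 == 0:
            print(step, scheduler.get_last_lr())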