#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from https://github.com/facebookresearch/dlrm/blob/mlperf/dlrm_s_pytorch.py

import sys

from torch.optim.lr_scheduler import _LRScheduler


class LRPolicyScheduler(_LRScheduler):
    """Learning-rate policy: linear warmup, optional hold, quadratic decay.

    Let ``s`` be the scheduler's ``_step_count`` (incremented by PyTorch on
    every ``step()``, including the initial one made during construction),
    and ``base_lr`` each parameter group's initial learning rate:

    * ``s < num_warmup_steps``: linear warmup,
      ``lr = base_lr * s / num_warmup_steps``.
    * ``decay_start_step <= s < decay_start_step + num_decay_steps``:
      ``lr = max(min_lr, base_lr * ((num_decay_steps - d) / num_decay_steps) ** 2)``
      where ``d = s - decay_start_step``.
    * otherwise, when ``num_decay_steps > 0``: the most recent lr produced by
      one of the two phases above is held (``base_lr`` if neither has run yet);
      when ``num_decay_steps <= 0``: ``base_lr``.

    Args:
        optimizer: The wrapped optimizer.
        num_warmup_steps: Length of the warmup phase, in scheduler steps.
        decay_start_step: Step at which decay begins; must not be smaller than
            ``num_warmup_steps``.
        num_decay_steps: Length of the decay phase; ``<= 0`` disables decay.
        min_lr: Floor applied to decayed learning rates. Defaults to 1e-7,
            the value that used to be hard-coded.

    Raises:
        SystemExit: if ``decay_start_step < num_warmup_steps``.
    """

    def __init__(
        self,
        optimizer,
        num_warmup_steps: int,
        decay_start_step: int,
        num_decay_steps: int,
        min_lr: float = 0.0000001,
    ):
        self.num_warmup_steps = num_warmup_steps
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_start_step + num_decay_steps
        self.num_decay_steps = num_decay_steps
        self.min_lr = min_lr
        # Most recent lr computed by the warmup/decay phases (None until one
        # runs). It must exist before super().__init__(), because PyTorch's
        # scheduler constructor already performs a step() that calls get_lr().
        self.last_lr = None

        if self.decay_start_step < self.num_warmup_steps:
            # NOTE: terminates via SystemExit instead of raising ValueError;
            # kept as-is for compatibility with the upstream DLRM code.
            sys.exit("Learning rate warmup must finish before the decay starts")

        super().__init__(optimizer)

    def get_lr(self):
        """Return the learning rate for every parameter group at this step."""
        step_count = self._step_count

        if step_count < self.num_warmup_steps:
            # warmup: scale == step_count / num_warmup_steps
            scale = 1.0 - (self.num_warmup_steps - step_count) / self.num_warmup_steps
            lr = [base_lr * scale for base_lr in self.base_lrs]
            self.last_lr = lr
        elif self.decay_start_step <= step_count < self.decay_end_step:
            # decay: quadratic in the fraction of decay steps still remaining
            decayed_steps = step_count - self.decay_start_step
            scale = ((self.num_decay_steps - decayed_steps) / self.num_decay_steps) ** 2
            lr = [max(self.min_lr, base_lr * scale) for base_lr in self.base_lrs]
            self.last_lr = lr
        elif self.num_decay_steps > 0:
            # Freeze at the last computed lr, either because we're after decay
            # or because we're between warmup and decay. If no phase has run
            # yet (e.g. num_warmup_steps <= 1 and decay_start_step > 1), hold
            # base_lr instead of failing on a missing last_lr.
            # NOTE(review): for num_warmup_steps >= 2 the value held between
            # warmup and decay is base_lr * (W - 1) / W, not base_lr; this
            # matches upstream DLRM and is intentionally unchanged.
            lr = self.last_lr if self.last_lr is not None else list(self.base_lrs)
        else:
            # do not adjust
            lr = self.base_lrs
        return lr