Commit 2210fa71 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot

Fix initial learning rate (#453)

Summary:
There was a very subtle bug here 😢 When we recently removed this line (7633129b), the learning rate scheduler didn't get initialized until after the first update. Unfortunately PyTorch optimizers store the learning rate in their internal state, so some learning rate schedulers use their `__init__` method to reset the learning rate to a sane initial value. This is especially problematic for LR schedulers that include a warmup: the optimizer is likely to contain the peak learning rate at initialization, and it's only in the LR scheduler's `__init__` that the (much smaller) warmup value is set.

For example, the inverse_sqrt scheduler resets the learning rate upon initialization:
https://github.com/pytorch/fairseq/blob/7853818c2e33a63ec17a31bcfe20e4fc75d94130/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py#L48-L50

**Impact:** For the last ~1.5 weeks, the first training update would use the optimizer's default learning rate instead of the initial rate set by the LR scheduler. All subsequent updates used the correct learning rates. This primarily affects LR schedulers with warmups.
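A toy illustration of the init-order bug (hypothetical `ToyOptimizer`/`ToyWarmupScheduler` classes, not fairseq's actual API): the optimizer is created with the peak learning rate, and only the scheduler's `__init__` resets it to the warmup value, so building the scheduler lazily after the first update means that update runs at the peak rate.

```python
class ToyOptimizer:
    """Stands in for a PyTorch optimizer, which stores the lr in its state."""
    def __init__(self, lr):
        self.lr = lr

class ToyWarmupScheduler:
    """Mimics inverse_sqrt: __init__ resets the lr to a small warmup value."""
    def __init__(self, optimizer, warmup_init_lr=1e-7):
        self.optimizer = optimizer
        optimizer.lr = warmup_init_lr  # the reset that makes init order matter

peak_lr = 5e-4

# Buggy order: the scheduler is built lazily, after the first update,
# so the first update sees the optimizer's default (peak) learning rate.
opt = ToyOptimizer(peak_lr)
lr_at_first_update = opt.lr        # 5e-4, far too large for warmup
ToyWarmupScheduler(opt)            # too late to help the first update

# Fixed order: build the scheduler immediately after the optimizer.
opt = ToyOptimizer(peak_lr)
ToyWarmupScheduler(opt)
lr_at_first_update_fixed = opt.lr  # 1e-7, the intended warmup value

print(lr_at_first_update, lr_at_first_update_fixed)
```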
Pull Request resolved: https://github.com/pytorch/fairseq/pull/453

Differential Revision: D13704453

Pulled By: myleott

fbshipit-source-id: a946da30100f837c66bdc6b9b77b014ab4eb8764
parent 7853818c
@@ -92,7 +92,7 @@ class Trainer(object):
     @property
     def lr_scheduler(self):
         if self._lr_scheduler is None:
-            self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
+            self._build_optimizer()  # this will initialize self._lr_scheduler
         return self._lr_scheduler

     def _build_optimizer(self):
@@ -110,6 +110,10 @@ class Trainer(object):
             print('| NOTICE: your device may support faster training with --fp16')
         self._optimizer = optim.build_optimizer(self.args, params)

+        # We should initialize the learning rate scheduler immediately after
+        # building the optimizer, so that the initial learning rate is set.
+        self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
+
     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
         if distributed_utils.is_master(self.args):  # only save one checkpoint