Commit 60e16a35 authored by Tatiana Likhomanenko, committed by Facebook Github Bot

Fix warmup for fixed_schedule in case of first update (#1408)

Summary:
I hit the following error while using warmup with the fixed LR schedule:

```
Traceback (most recent call last):
  File "/private/home/antares/.conda/envs/fairseq-20190809/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/private/home/antares/work/unsupervised/blank_test/fairseq-py/train.py", line 291, in distributed_main
    main(args, init_distributed=True)
  File "/private/home/antares/work/unsupervised/blank_test/fairseq-py/train.py", line 81, in main
    train(args, trainer, task, epoch_itr)
  File "/private/home/antares/work/unsupervised/blank_test/fairseq-py/train.py", line 122, in train
    log_output = trainer.train_step(samples)
  File "/private/home/antares/work/unsupervised/blank_test/fairseq-py/fairseq/trainer.py", line 409, in train_step
    self.optimizer.step()
  File "/private/home/antares/work/unsupervised/blank_test/fairseq-py/fairseq/optim/fp16_optimizer.py", line 153, in step
    self.fp32_optimizer.step(closure)
  File "/private/home/antares/work/unsupervised/blank_test/fairseq-py/fairseq/optim/fairseq_optimizer.py", line 98, in step
    self.optimizer.step(closure)
  File "/private/home/antares/work/unsupervised/blank_test/fairseq-py/fairseq/optim/nag.py", line 68, in step
    lr_correct = lr / lr_old
ZeroDivisionError: float division by zero
```
This is because `num_updates=0` on the first iteration, so the warmup factor, and therefore the `lr` we set on the optimizer, is zero; NAG's `lr / lr_old` then divides by zero.
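For illustration, here is a minimal sketch of the failure mode (the variable names, the hypothetical config values, and the `lr_old` fallback are assumptions modelled on the `nag.py` frame in the traceback, not the actual fairseq code):

```
warmup_updates = 4000   # hypothetical value of --warmup-updates
base_lr = 0.25          # hypothetical base learning rate

# Old warmup factor on the very first update (num_updates == 0):
num_updates = 0
warmup_factor = num_updates / float(warmup_updates)  # 0.0
lr = warmup_factor * base_lr                         # 0.0

# NAG rescales momentum by lr / lr_old; on the first step there is no
# previous lr stored, so lr_old falls back to the (zero) current lr:
lr_old = lr
lr_correct = lr / lr_old  # ZeroDivisionError: float division by zero
```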
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1408

Differential Revision: D18637526

Pulled By: myleott

fbshipit-source-id: fdd81dd69b1b38bc21a4fa315b4e25cee03af6bf
```
@@ -53,7 +53,7 @@ class FixedSchedule(FairseqLRScheduler):
     def step_update(self, num_updates):
         """Update the learning rate after each update."""
-        if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
-            self.warmup_factor = num_updates / float(self.args.warmup_updates)
+        if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates:
+            self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates)
         self.optimizer.set_lr(self.warmup_factor * self.lr)
         return self.optimizer.get_lr()
```
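As a sanity check, a quick sketch (with a hypothetical `warmup_updates = 4`, purely for illustration) shows the off-by-one shift: the first update now gets a nonzero factor, and warmup still ends at the full rate:

```
warmup_updates = 4

# Old formula: factor is 0.0 on the first update (num_updates == 0).
old = [u / float(warmup_updates) for u in range(warmup_updates + 1)]
# -> [0.0, 0.25, 0.5, 0.75, 1.0]

# New formula: the first update already gets a nonzero factor,
# and warmup still finishes at the full rate.
new = [(u + 1) / float(warmup_updates) for u in range(warmup_updates)]
# -> [0.25, 0.5, 0.75, 1.0]
```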