Commit 8defa9d9 authored by Yilei Li, committed by Facebook GitHub Bot

Add warmup support in reduce_on_plateau lr schedule

Summary:
Enables the reduce_on_plateau schedule with an optional warmup phase, where we linearly increase the learning rate from some initial learning rate (``--warmup-init-lr``) until the configured learning rate (``--lr``). Thereafter the lr is adjusted according to the original reduce_on_plateau scheme.
During warmup::

      lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
      lr = lrs[update_num]
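For example, with ``--warmup-init-lr 0 --lr 0.001 --warmup-updates 5`` (illustrative values chosen for this sketch, not defaults)::

      import torch
      lrs = torch.linspace(0.0, 1e-3, 5)
      # lrs holds 0.0, 0.00025, 0.0005, 0.00075, 0.001 -- one lr per warmup update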

Reviewed By: yqwangustc

Differential Revision: D17779925

fbshipit-source-id: c3bfb3321c76850824fc42df4fac4e5dcf73fbf8
parent e49b302a
@@ -10,7 +10,17 @@ from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
    """
    Decay the LR by a factor every time the validation loss plateaus.

    Also comes with an optional warmup phase, where we linearly increase the
    learning rate from some initial learning rate (``--warmup-init-lr``) until
    the configured learning rate (``--lr``). Thereafter the lr is adjusted
    according to the original reduce_on_plateau scheme.

    During warmup::

        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
        lr = lrs[update_num]
    """

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)
@@ -22,6 +32,20 @@ class ReduceLROnPlateau(FairseqLRScheduler):
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer.optimizer, patience=0, factor=args.lr_shrink,
            threshold=args.lr_threshold)
        warmup_end_lr = args.lr[0]

        # if no warmup, set the initial lr to be args.lr[0]
        if args.warmup_init_lr < 0:
            args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr

        # linearly warmup for the first args.warmup_updates
        if args.warmup_updates > 0:
            self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates

        # this flag is either set from the args when there is no warmup, or
        # set by step_update() when warmup finishes
        self.warmup_end = args.warmup_updates <= 0

        # initial learning rate; this self.lr is used only during the init
        # and/or warmup period
        self.lr = args.warmup_init_lr
        self.optimizer.set_lr(self.lr)
    @staticmethod
    def add_args(parser):
@@ -32,6 +56,10 @@ class ReduceLROnPlateau(FairseqLRScheduler):
        parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT',
                            help='threshold for measuring the new optimum, '
                                 'to only focus on significant changes')
        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
                            help='initial learning rate during warmup phase; default is args.lr')
        # fmt: on

    def state_dict(self):
@@ -48,9 +76,23 @@ class ReduceLROnPlateau(FairseqLRScheduler):
        self.lr_scheduler.last_epoch = state_dict['last_epoch']
    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch if warmup
        has finished; otherwise do not update the lr on epoch boundaries."""
        if val_loss is not None and self.warmup_end:
            self.lr_scheduler.step(val_loss, epoch)
        else:
            self.lr_scheduler.last_epoch = epoch
        return self.optimizer.get_lr()
    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        # if there is a warmup phase and it is still in progress, step the
        # lr linearly; once it finishes, flip the flag and do nothing more
        # here (the plateau scheduler acts only on epoch boundaries)
        if self.args.warmup_updates > 0:
            if num_updates <= self.args.warmup_updates:
                self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
                self.optimizer.set_lr(self.lr)
            elif not self.warmup_end:
                self.warmup_end = True
        return self.optimizer.get_lr()
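As a rough, self-contained sketch of the combined schedule this diff implements: linear warmup driven per update, then torch's ``ReduceLROnPlateau`` driven at epoch boundaries. This re-implements the behavior standalone rather than using the class above, and all concrete values (warmup length, epoch counts, the constant validation loss) are made-up examples::

      import torch

      # hypothetical settings mirroring the new flags in this diff
      warmup_init_lr, target_lr, warmup_updates = 1e-7, 1e-3, 100

      param = torch.nn.Parameter(torch.zeros(1))
      opt = torch.optim.SGD([param], lr=warmup_init_lr)
      plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
          opt, patience=0, factor=0.1, threshold=1e-4)

      lr_step = (target_lr - warmup_init_lr) / warmup_updates
      num_updates = 0
      for epoch in range(5):
          for _ in range(50):                      # 50 updates per epoch
              num_updates += 1
              if num_updates <= warmup_updates:    # step_update(): linear warmup
                  for g in opt.param_groups:
                      g['lr'] = warmup_init_lr + num_updates * lr_step
          if num_updates > warmup_updates:         # step(): plateau logic only
              plateau.step(1.0)                    # constant val loss => lr shrinks
          print(epoch, opt.param_groups[0]['lr'])

Running this, the lr ramps to 1e-3 over the first two epochs, then shrinks by the factor 0.1 on each subsequent epoch once the (constant) validation loss stops improving.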