Commit ba9f32cc authored by alexeib and committed by Myle Ott

add warmup support back to cosine lr sched (important for mt)

parent 8d6665f2
@@ -13,16 +13,18 @@ from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('cosine')
class CosineSchedule(FairseqLRScheduler):
"""Assign LR based on a cyclical schedule that follows the cosine function.
See https://arxiv.org/pdf/1608.03983.pdf for details
We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (`--warmup-init-lr`) until the configured
learning rate (`--lr`).
During warmup:
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup:
lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))
where
t_curr is current percentage of updates within the current period range
t_i is the current period range, which is scaled by t_mul after every iteration
"""
def __init__(self, args, optimizer):
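For readers skimming the diff, the schedule the docstring describes can be summarized in a short standalone sketch: a linear warmup into `--max-lr`, followed by cosine cycles whose bounds shrink by `lr_shrink` and whose length grows by `t_mult` after every restart. The function name and the numeric defaults below are illustrative only; this is not the fairseq class itself.

```python
import math

def cosine_lr(update, warmup_updates=500, warmup_init_lr=1e-7,
              min_lr=1e-5, max_lr=1e-3, period=5000, t_mult=1.0, lr_shrink=1.0):
    """Sketch of the schedule: linear warmup, then shrinking cosine cycles."""
    if update < warmup_updates:
        # linear ramp from warmup_init_lr up to max_lr
        step = (max_lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + update * step
    curr = update - warmup_updates
    if t_mult != 1:
        # closed form for the index of the current (growing) period
        i = math.floor(math.log(1 - curr / period * (1 - t_mult), t_mult))
        t_i = t_mult ** i * period
        t_curr = curr - (1 - t_mult ** i) / (1 - t_mult) * period
    else:
        i = math.floor(curr / period)
        t_i = period
        t_curr = curr - period * i
    shrink = lr_shrink ** i
    lo, hi = min_lr * shrink, max_lr * shrink
    return lo + 0.5 * (hi - lo) * (1 + math.cos(math.pi * t_curr / t_i))

# cosine_lr(0) ≈ 1e-07 (start of warmup), cosine_lr(500) ≈ 1e-03 (peak),
# cosine_lr(3000) ≈ 5.05e-04 (halfway down the first cosine cycle)
```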
@@ -33,22 +35,38 @@ class CosineSchedule(FairseqLRScheduler):
                ' Consider --lr-scheduler=fixed instead.'
            )

-        self.min_lr = args.lr[0]
+        warmup_end_lr = args.max_lr
+        if args.warmup_init_lr < 0:
+            args.warmup_init_lr = args.lr[0]
+        self.min_lr = args.lr[0]
        self.max_lr = args.max_lr
        assert self.max_lr > self.min_lr, 'max_lr must be more than lr'
        self.t_mult = args.t_mult
        self.period = args.lr_period_updates
+        if args.warmup_updates > 0:
+            # linearly warmup for the first args.warmup_updates
+            self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
+        else:
+            self.lr_step = 1
+        self.warmup_updates = args.warmup_updates
        self.lr_shrink = args.lr_shrink
        # initial learning rate
-        self.lr = self.max_lr
+        self.lr = args.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
+        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
+                            help='warmup the learning rate linearly for the first N updates')
+        parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
+                            help='initial learning rate during warmup phase; default is args.lr')
        parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
                            help='max learning rate, must be more than args.lr')
        parser.add_argument('--t-mult', default=1, type=float, metavar='LR',
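To make the warmup bookkeeping in `__init__` concrete, here is a small worked example of the `lr_step` slope it precomputes. The flag values are made up for illustration and simply mirror the options registered above:

```python
# Hypothetical settings: --lr 1e-05, --max-lr 1e-03, --warmup-init-lr -1, --warmup-updates 4000
lr = [1e-05]               # fairseq keeps --lr as a list; lr[0] is what this scheduler reads
max_lr = 1e-03
warmup_init_lr = -1
warmup_updates = 4000

warmup_end_lr = max_lr                     # warmup ramps up to --max-lr
if warmup_init_lr < 0:
    warmup_init_lr = lr[0]                 # default warmup start is --lr
lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates
print(lr_step)                             # ≈ 2.475e-07 added to the LR on every warmup update
```

After 4000 updates the learning rate has climbed from 1e-05 to 1e-03, at which point the cosine phase takes over at its maximum.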
@@ -64,20 +82,24 @@ class CosineSchedule(FairseqLRScheduler):
    def step_update(self, num_updates):
        """Update the learning rate after each update."""
-        if self.t_mult != 1:
-            i = math.floor(math.log(1 - num_updates / self.period * (1 - self.t_mult), self.t_mult))
-            t_i = self.t_mult ** i * self.period
-            t_curr = num_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
+        if num_updates < self.args.warmup_updates:
+            self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
        else:
-            i = math.floor(num_updates / self.period)
-            t_i = self.period
-            t_curr = num_updates - (self.period * i)
-        lr_shrink = self.lr_shrink ** i
-        min_lr = self.min_lr * lr_shrink
-        max_lr = self.max_lr * lr_shrink
-        self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
+            curr_updates = num_updates - self.args.warmup_updates
+            if self.t_mult != 1:
+                i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult))
+                t_i = self.t_mult ** i * self.period
+                t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
+            else:
+                i = math.floor(curr_updates / self.period)
+                t_i = self.period
+                t_curr = curr_updates - (self.period * i)
+            lr_shrink = self.lr_shrink ** i
+            min_lr = self.min_lr * lr_shrink
+            max_lr = self.max_lr * lr_shrink
+            self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
        self.optimizer.set_lr(self.lr)
-        return self.lr
+        return self.lr
\ No newline at end of file
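The closed-form period index in the `t_mult != 1` branch of `step_update` is the least obvious part of the patch: successive periods have lengths `period, period*t_mult, period*t_mult**2, ...`, so period `i` starts at `period * (1 - t_mult**i) / (1 - t_mult)`, and the `floor(log(...))` expression inverts that geometric series. A small self-contained sanity check (the numbers are arbitrary) against a naive scan of the period boundaries:

```python
import math

def period_index_closed_form(curr, period, t_mult):
    # Same expression as step_update: invert start(i) = period * (1 - t_mult**i) / (1 - t_mult)
    return math.floor(math.log(1 - curr / period * (1 - t_mult), t_mult))

def period_index_naive(curr, period, t_mult):
    # Walk the boundaries one period at a time: period, period*t_mult, period*t_mult**2, ...
    i, start, length = 0, 0.0, float(period)
    while curr >= start + length:
        start += length
        length *= t_mult
        i += 1
    return i

for curr in range(0, 20000, 37):   # arbitrary sample of post-warmup update counts
    assert period_index_closed_form(curr, 1000, 2.0) == period_index_naive(curr, 1000, 2.0)
```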