"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "a091fadfdaa2e4d6a34cf8bbfe4012913367a35a"
Commit ba9f32cc authored by alexeib, committed by Myle Ott

add warmup support back to cosine lr sched (important for mt)

parent 8d6665f2
@@ -13,16 +13,18 @@ from . import FairseqLRScheduler, register_lr_scheduler
 @register_lr_scheduler('cosine')
 class CosineSchedule(FairseqLRScheduler):
     """Assign LR based on a cyclical schedule that follows the cosine function.

     See https://arxiv.org/pdf/1608.03983.pdf for details.

+    We also support a warmup phase where we linearly increase the learning rate
+    from some initial learning rate (`--warmup-init-lr`) until the configured
+    learning rate (`--lr`).
+
+    During warmup:
+
+        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
+        lr = lrs[update_num]
+
+    After warmup:
+
     lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))

     where
         t_curr is current percentage of updates within the current period range
         t_i is the current period range, which is scaled by t_mult after every iteration
     """

     def __init__(self, args, optimizer):
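The docstring added above compresses the whole schedule into two formulas. As a sanity check, here is a minimal standalone sketch of the same behavior (a hypothetical helper, not part of fairseq) for the simplest case t_mult = 1 and lr_shrink = 1, where every period has the same length and amplitude:

import math

def schedule_lr(update, warmup_init_lr, max_lr, min_lr, warmup_updates, period):
    # linear warmup from warmup_init_lr to max_lr over warmup_updates steps
    if update < warmup_updates:
        lr_step = (max_lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + update * lr_step
    # cosine decay from max_lr to min_lr, restarting every `period` updates
    t_curr = (update - warmup_updates) % period
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / period))

# lr ramps 1e-7 -> 1e-3 over 100 updates, then decays toward 1e-4 and restarts
for u in (0, 50, 100, 600, 1100):
    print(u, schedule_lr(u, 1e-7, 1e-3, 1e-4, warmup_updates=100, period=1000))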
@@ -33,22 +35,38 @@ class CosineSchedule(FairseqLRScheduler):
                 ' Consider --lr-scheduler=fixed instead.'
             )

-        self.min_lr = args.lr[0]
+        warmup_end_lr = args.max_lr
+        if args.warmup_init_lr < 0:
+            args.warmup_init_lr = args.lr[0]
+
+        self.min_lr = args.lr[0]
         self.max_lr = args.max_lr

         assert self.max_lr > self.min_lr, 'max_lr must be more than lr'

         self.t_mult = args.t_mult
         self.period = args.lr_period_updates

+        if args.warmup_updates > 0:
+            # linearly warmup for the first args.warmup_updates
+            self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
+        else:
+            self.lr_step = 1
+
+        self.warmup_updates = args.warmup_updates
         self.lr_shrink = args.lr_shrink

         # initial learning rate
-        self.lr = self.max_lr
+        self.lr = args.warmup_init_lr
         self.optimizer.set_lr(self.lr)

     @staticmethod
     def add_args(parser):
         """Add arguments to the parser for this LR scheduler."""
+        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
+                            help='warmup the learning rate linearly for the first N updates')
+        parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
+                            help='initial learning rate during warmup phase; default is args.lr')
         parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
                             help='max learning rate, must be more than args.lr')
         parser.add_argument('--t-mult', default=1, type=float, metavar='LR',
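The two new flags interact with the existing ones in a slightly indirect way: --warmup-init-lr defaults to -1, which __init__ treats as "start warmup from args.lr", and the per-update warmup increment is precomputed as lr_step. A small sketch of that wiring (a hypothetical standalone parser mirroring add_args; the base lr value is an assumed stand-in for args.lr[0]):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--warmup-updates', default=0, type=int)
parser.add_argument('--warmup-init-lr', default=-1, type=float)
parser.add_argument('--max-lr', required=True, type=float)
args = parser.parse_args(['--max-lr', '1e-3', '--warmup-updates', '4000'])

base_lr = 1e-4                       # stand-in for args.lr[0]
warmup_end_lr = args.max_lr
if args.warmup_init_lr < 0:
    args.warmup_init_lr = base_lr    # default: warmup starts at the base lr
lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
print(args.warmup_init_lr, lr_step)  # 0.0001 2.25e-07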
@@ -64,20 +82,24 @@ class CosineSchedule(FairseqLRScheduler):

     def step_update(self, num_updates):
         """Update the learning rate after each update."""
-        if self.t_mult != 1:
-            i = math.floor(math.log(1 - num_updates / self.period * (1 - self.t_mult), self.t_mult))
-            t_i = self.t_mult ** i * self.period
-            t_curr = num_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
+        if num_updates < self.args.warmup_updates:
+            # linear warmup phase
+            self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
         else:
-            i = math.floor(num_updates / self.period)
-            t_i = self.period
-            t_curr = num_updates - (self.period * i)
+            # cosine phase: count updates from the end of warmup
+            curr_updates = num_updates - self.args.warmup_updates
+            if self.t_mult != 1:
+                i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult))
+                t_i = self.t_mult ** i * self.period
+                t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
+            else:
+                i = math.floor(curr_updates / self.period)
+                t_i = self.period
+                t_curr = curr_updates - (self.period * i)

-        lr_shrink = self.lr_shrink ** i
-        min_lr = self.min_lr * lr_shrink
-        max_lr = self.max_lr * lr_shrink
+            lr_shrink = self.lr_shrink ** i
+            min_lr = self.min_lr * lr_shrink
+            max_lr = self.max_lr * lr_shrink

-        self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
+            self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))

         self.optimizer.set_lr(self.lr)
         return self.lr
\ No newline at end of file
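The closed form in the t_mult != 1 branch deserves a note: the first i periods consume period * (1 - t_mult**i) / (1 - t_mult) updates (a geometric series), and the floor(log(..., t_mult)) expression inverts that sum to recover the current period index i from a raw update count. A short verification sketch (the values are illustrative, not from the commit):

import math

period, t_mult = 1000, 2.0

def period_index(curr_updates):
    # invert the geometric series: find i such that the first i periods
    # (of lengths period, period*t_mult, period*t_mult**2, ...) fit in curr_updates
    return math.floor(math.log(1 - curr_updates / period * (1 - t_mult), t_mult))

# boundaries with doubling periods: [0, 1000) -> 0, [1000, 3000) -> 1, [3000, 7000) -> 2
for u in (0, 999, 1000, 2999, 3000, 6999):
    print(u, period_index(u))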