"tests/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "7730a79fcdd5fddbb8c695b3944b692392d8db91"
Commit 7c7634f6 authored by Myle Ott's avatar Myle Ott
Browse files

Support --warmup-updates with fixed LR schedule

parent 0daba38e
...@@ -16,16 +16,22 @@ class FixedSchedule(FairseqLRScheduler): ...@@ -16,16 +16,22 @@ class FixedSchedule(FairseqLRScheduler):
def __init__(self, args, optimizer):
    """Set up the fixed LR schedule.

    Takes the base learning rate from ``args.lr[0]`` and precomputes a
    linear warmup multiplier when ``--warmup-updates`` is positive.
    """
    super().__init__(args, optimizer)
    # Epoch-level base LR: the first entry of the configured LR list.
    self.lr = args.lr[0]
    # With warmup enabled, start the factor at 1/N so the very first
    # update already applies a non-zero LR; otherwise it is a no-op 1.
    self.warmup_factor = 1. / args.warmup_updates if args.warmup_updates > 0 else 1
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
"""Add arguments to the parser for this LR scheduler.""" """Add arguments to the parser for this LR scheduler."""
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch') help='force annealing at specified epoch')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
def get_next_lr(self, epoch):
    """Return the LR for ``epoch``, before any warmup scaling is applied."""
    lrs = self.args.lr
    if self.args.force_anneal is None or epoch < self.args.force_anneal:
        # Fixed schedule: index the configured LR list, clamping to its
        # last entry once the epoch runs past the end of the list.
        # NOTE(review): this assignment is hidden behind the diff fold in
        # the source view — reconstructed from context; confirm against
        # the full file.
        next_lr = lrs[min(epoch, len(lrs) - 1)]
    else:
        # Anneal: decay the last configured LR by lr_shrink for every
        # epoch past the forced annealing point.
        next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
    return next_lr
def step(self, epoch, val_loss=None):
    """Update the learning rate at the end of the given epoch."""
    super().step(epoch, val_loss)
    # Recompute the epoch-level LR, then push it to the optimizer scaled
    # by the current warmup factor (1 once warmup has completed).
    self.lr = self.get_next_lr(epoch)
    self.optimizer.set_lr(self.warmup_factor * self.lr)
    return self.optimizer.get_lr()
def step_update(self, num_updates):
    """Update the learning rate after each update.

    During the first ``--warmup-updates`` updates the LR is scaled
    linearly from 0 up to the full epoch-level LR; after warmup the LR
    chosen by ``step`` is left in place.

    Returns the optimizer's current learning rate.
    """
    # Guard on warmup_updates > 0: with the default of 0, the bare test
    # `num_updates <= self.args.warmup_updates` passes for
    # num_updates == 0 and the division below raises ZeroDivisionError.
    if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
        self.warmup_factor = num_updates / float(self.args.warmup_updates)
        self.optimizer.set_lr(self.warmup_factor * self.lr)
    return self.optimizer.get_lr()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment