Commit dbed5e07 authored by Vijay Korthikanti

inverse_square_root learning rate schedule

parent c7d57ff7
@@ -649,7 +649,7 @@ def _add_learning_rate_args(parser):
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
-                      choices=['constant', 'linear', 'cosine'],
+                      choices=['constant', 'linear', 'cosine', 'inverse-square-root'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over,'
......
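The new decay style is selected through the existing --lr-decay-style flag. A minimal, standalone sketch of how the option parses; it re-creates only this one argument outside of Megatron (the real argument is registered by _add_learning_rate_args as in the hunk above):

import argparse

# Illustrative standalone parser, not Megatron's full argument setup.
parser = argparse.ArgumentParser()
parser.add_argument('--lr-decay-style', type=str, default='linear',
                    choices=['constant', 'linear', 'cosine', 'inverse-square-root'],
                    help='Learning rate decay function.')

args = parser.parse_args(['--lr-decay-style', 'inverse-square-root'])
print(args.lr_decay_style)  # -> inverse-square-root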
@@ -90,8 +90,14 @@ class OptimizerParamScheduler(object):
        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
        if self.num_steps > self.lr_decay_steps:
            return self.min_lr
        # If we are done with the warmup period, use the decay style.
+       if self.lr_decay_style == 'inverse-square-root':
+           warmup_steps = max(self.lr_warmup_steps, 1)
+           num_steps = max(self.num_steps, 1)
+           lr = self.max_lr * warmup_steps ** 0.5 / (num_steps ** 0.5)
+           return max(self.min_lr, lr)
        num_steps_ = self.num_steps - self.lr_warmup_steps
        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
......
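The inverse-square-root branch added above is continuous with the end of warmup: when num_steps equals lr_warmup_steps the factor warmup_steps ** 0.5 / num_steps ** 0.5 is 1, so decay starts at max_lr and then falls off proportionally to 1/sqrt(step), clamped below by min_lr. A self-contained sketch of the full schedule; the function name and the linear-warmup branch are illustrative assumptions, not part of this diff:

def inverse_sqrt_lr(num_steps, max_lr, min_lr, lr_warmup_steps):
    # Linear warmup branch (assumed; in the scheduler it is handled
    # before the decay-style dispatch).
    if lr_warmup_steps > 0 and num_steps <= lr_warmup_steps:
        return max_lr * num_steps / lr_warmup_steps
    # Inverse-square-root decay, mirroring the added code above.
    warmup_steps = max(lr_warmup_steps, 1)
    steps = max(num_steps, 1)
    lr = max_lr * warmup_steps ** 0.5 / (steps ** 0.5)
    return max(min_lr, lr)

# Example: max_lr=1e-3, min_lr=1e-5, 1000 warmup steps.
for step in (500, 1000, 4000, 100000):
    print(step, inverse_sqrt_lr(step, 1e-3, 1e-5, 1000))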