Commit 1eda0a17 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'refactor_learning_rate' into 'blendable_dataset'

Refactor learning rate so it is easier to make learning rate based on consumed samples

See merge request ADLR/megatron-lm!179
parents 98989693 fa80af26
...@@ -280,7 +280,7 @@ def _add_learning_rate_args(parser): ...@@ -280,7 +280,7 @@ def _add_learning_rate_args(parser):
'and initial warmup, the learing rate at each ' 'and initial warmup, the learing rate at each '
'iteration would be different.') 'iteration would be different.')
group.add_argument('--lr-decay-style', type=str, default='linear', group.add_argument('--lr-decay-style', type=str, default='linear',
choices=['constant', 'linear', 'cosine', 'exponential'], choices=['constant', 'linear', 'cosine'],
help='Learning rate decay function.') help='Learning rate decay function.')
group.add_argument('--lr-decay-iters', type=int, default=None, group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay learning rate over,' help='number of iterations to decay learning rate over,'
......
...@@ -19,77 +19,101 @@ import math ...@@ -19,77 +19,101 @@ import math
from megatron import print_rank_0 from megatron import print_rank_0
class AnnealingLR(object): class AnnealingLR(object):
"""Anneals the learning rate.""" """Anneals the learning rate."""
def __init__(self, optimizer, start_lr, def __init__(self, optimizer, max_lr, min_lr,
warmup_iter, total_iters, warmup_steps, decay_steps,
decay_style, last_iter, min_lr=0.0, decay_style, num_steps,
use_checkpoint_lr_scheduler=True, use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False): override_lr_scheduler=False):
# Class values. # Class values.
self.optimizer = optimizer self.optimizer = optimizer
self.start_lr = start_lr
self.max_lr = float(max_lr)
self.min_lr = min_lr self.min_lr = min_lr
self.warmup_iter = warmup_iter assert self.min_lr >= 0.0
self.num_iters = last_iter assert self.max_lr >= self.min_lr
self.end_iter = total_iters
assert self.end_iter > 0 self.warmup_steps = warmup_steps
self.num_steps = num_steps
self.decay_steps = decay_steps
assert self.decay_steps > 0
assert self.warmup_steps < self.decay_steps
self.decay_style = decay_style self.decay_style = decay_style
self.override_lr_scheduler = override_lr_scheduler self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler: if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\ assert not self.use_checkpoint_lr_scheduler, 'both override and '\
'use-checkpoint are set.' 'use-checkpoint are set.'
# Set the learning rate # Set the learning rate
self.step(self.num_iters) self.step(step_num=self.num_steps)
print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
def get_lr(self): def get_lr(self):
"""Learning rate decay functions from: """Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) # Use linear warmup for the initial part.
# Warmup. if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: return self.max_lr * float(self.num_steps) / \
return float(self.start_lr) * num_iters_ / self.warmup_iter float(self.warmup_steps)
# If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant':
return self.max_lr
# For any steps larger than `self.decay_steps`, use `self.min_lr`.
if self.num_steps > self.decay_steps:
return self.min_lr
# If we are done with the warmup period, use the decay style.
num_steps_ = self.num_steps - self.warmup_steps
decay_steps_ = self.decay_steps - self.warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = self.max_lr - self.min_lr
num_iters_ = num_iters_ - self.warmup_iter
if self.decay_style == 'linear': if self.decay_style == 'linear':
lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter coeff = (1.0 - decay_ratio)
elif self.decay_style == 'cosine': elif self.decay_style == 'cosine':
lr = self.start_lr / 2.0 * (math.cos( coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
math.pi * num_iters_ / self.end_iter) + 1)
elif self.decay_style == 'exponential':
# exp(-0.693) = 1/2
lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
else: else:
lr = self.start_lr raise Exception('{} decay style is not supported.'.format(
return max(lr, self.min_lr) self.decay_style))
return self.min_lr + coeff * delta_lr
def step(self, step_num=None):
def step(self, increment=1, step_num=None):
"""Set lr for all parameters groups.""" """Set lr for all parameters groups."""
if step_num is None: if step_num is None:
step_num = self.num_iters + 1 step_num = self.num_steps + increment
self.num_iters = step_num self.num_steps = step_num
new_lr = self.get_lr() new_lr = self.get_lr()
for group in self.optimizer.param_groups: for group in self.optimizer.param_groups:
group['lr'] = new_lr group['lr'] = new_lr
def state_dict(self): def state_dict(self):
state_dict = { state_dict = {
'start_lr': self.start_lr, 'max_lr': self.max_lr,
'warmup_iter': self.warmup_iter, 'warmup_steps': self.warmup_steps,
'num_iters': self.num_iters, 'num_steps': self.num_steps,
'decay_style': self.decay_style, 'decay_style': self.decay_style,
'end_iter': self.end_iter, 'decay_steps': self.decay_steps,
'min_lr': self.min_lr 'min_lr': self.min_lr
} }
return state_dict return state_dict
def _check_and_set(self, cls_value, sd_value, name): def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and """Auxiliary function for checking the values in the checkpoint and
setting them.""" setting them."""
...@@ -104,20 +128,39 @@ class AnnealingLR(object): ...@@ -104,20 +128,39 @@ class AnnealingLR(object):
name)) name))
return sd_value return sd_value
def load_state_dict(self, sd): def load_state_dict(self, sd):
self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], if 'start_lr' in sd:
'learning rate') max_lr_ = sd['start_lr']
else:
max_lr_ = sd['max_lr']
self.max_lr = self._check_and_set(self.max_lr, max_lr_,
'learning rate')
self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
'minimum learning rate') 'minimum learning rate')
self.warmup_iter = self._check_and_set(self.warmup_iter,
sd['warmup_iter'], if 'warmup_iter' in sd:
'warmup iterations') warmup_steps_ = sd['warmup_iter']
self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'], else:
'total number of iterations') warmup_steps_ = sd['warmup_steps']
self.warmup_steps = self._check_and_set(self.warmup_steps,
warmup_steps_,
'warmup iterations')
if 'end_iter' in sd:
decay_steps_ = sd['end_iter']
else:
decay_steps_ = sd['decay_steps']
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
'total number of iterations')
self.decay_style = self._check_and_set(self.decay_style, self.decay_style = self._check_and_set(self.decay_style,
sd['decay_style'], sd['decay_style'],
'decay style') 'decay style')
self.num_iters = sd['num_iters'] if 'num_iters' in sd:
self.step(self.num_iters) self.num_steps = sd['num_iters']
else:
self.num_steps = sd['num_steps']
self.step(step_num=self.num_steps)
...@@ -194,12 +194,12 @@ def get_learning_rate_scheduler(optimizer): ...@@ -194,12 +194,12 @@ def get_learning_rate_scheduler(optimizer):
warmup_iter = args.warmup * num_iters warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR( lr_scheduler = AnnealingLR(
optimizer, optimizer,
start_lr=args.lr, max_lr=args.lr,
warmup_iter=warmup_iter,
total_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step,
min_lr=args.min_lr, min_lr=args.min_lr,
warmup_steps=warmup_iter,
decay_steps=num_iters,
decay_style=args.lr_decay_style,
num_steps=init_step,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler) override_lr_scheduler=args.override_lr_scheduler)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment