Commit 641408f5 authored by Vijay Korthikanti

more naming cleanup

parent 04ecc834
@@ -13,18 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Learning rate decay functions."""
+"""Learning rate decay and weight decay incr functions."""
 
 import math
 
 from megatron import print_rank_0
 
 
 class OptimizerParamScheduler(object):
-    """Anneals the learning rate."""
+    """Anneals learning rate and weight decay"""
 
     def __init__(self, optimizer, max_lr, min_lr,
-                 warmup_steps, decay_steps, decay_style,
-                 start_wd, end_wd, wd_incr_style,
+                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
+                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
                  use_checkpoint_opt_param_scheduler=True,
                  override_opt_param_scheduler=False):
@@ -36,19 +36,19 @@ class OptimizerParamScheduler(object):
         assert self.min_lr >= 0.0
         assert self.max_lr >= self.min_lr
-        self.warmup_steps = warmup_steps
+        self.lr_warmup_steps = lr_warmup_steps
         self.num_steps = 0
-        self.decay_steps = decay_steps
-        assert self.decay_steps > 0
-        assert self.warmup_steps < self.decay_steps
-        self.decay_style = decay_style
+        self.lr_decay_steps = lr_decay_steps
+        assert self.lr_decay_steps > 0
+        assert self.lr_warmup_steps < self.lr_decay_steps
+        self.lr_decay_style = lr_decay_style
 
         self.start_wd = start_wd
         self.end_wd = end_wd
         assert self.start_wd >= 0.0
         assert self.end_wd >= self.start_wd
+        self.wd_incr_steps = wd_incr_steps
         self.wd_incr_style = wd_incr_style
 
         self.override_opt_param_scheduler = override_opt_param_scheduler
@@ -59,26 +59,27 @@ class OptimizerParamScheduler(object):
         # Set the learning rate
         self.step(0)
 
-        print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
+        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
 
 
     def get_wd(self):
-        if self.num_steps > self.decay_steps:
+        """ Weight decay incr functions"""
+        if self.num_steps > self.wd_incr_steps:
             return self.end_wd
 
         if self.wd_incr_style == 'constant':
             assert self.start_wd == self.end_wd
             return self.end_wd
 
-        decay_ratio = float(self.num_steps) / float(self.decay_steps)
-        assert decay_ratio >= 0.0
-        assert decay_ratio <= 1.0
+        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
+        assert incr_ratio >= 0.0
+        assert incr_ratio <= 1.0
         delta_wd = self.end_wd - self.start_wd
 
         if self.wd_incr_style == 'linear':
-            coeff = decay_ratio
+            coeff = incr_ratio
         elif self.wd_incr_style == 'cosine':
-            coeff = 0.5 * (math.cos(math.pi * (1 - decay_ratio)) + 1.0)
+            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
         else:
             raise Exception('{} weight decay increment style is not supported.'.format(
                 self.wd_incr_style))
@@ -91,33 +92,33 @@ class OptimizerParamScheduler(object):
            https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
 
         # Use linear warmup for the initial part.
-        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
+        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
             return self.max_lr * float(self.num_steps) / \
-                float(self.warmup_steps)
+                float(self.lr_warmup_steps)
 
         # If the learning rate is constant, just return the initial value.
-        if self.decay_style == 'constant':
+        if self.lr_decay_style == 'constant':
             return self.max_lr
 
-        # For any steps larger than `self.decay_steps`, use `self.min_lr`.
-        if self.num_steps > self.decay_steps:
+        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
+        if self.num_steps > self.lr_decay_steps:
             return self.min_lr
 
         # If we are done with the warmup period, use the decay style.
-        num_steps_ = self.num_steps - self.warmup_steps
-        decay_steps_ = self.decay_steps - self.warmup_steps
+        num_steps_ = self.num_steps - self.lr_warmup_steps
+        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
         decay_ratio = float(num_steps_) / float(decay_steps_)
         assert decay_ratio >= 0.0
         assert decay_ratio <= 1.0
         delta_lr = self.max_lr - self.min_lr
 
-        if self.decay_style == 'linear':
+        if self.lr_decay_style == 'linear':
             coeff = (1.0 - decay_ratio)
-        elif self.decay_style == 'cosine':
+        elif self.lr_decay_style == 'cosine':
             coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
         else:
             raise Exception('{} decay style is not supported.'.format(
-                self.decay_style))
+                self.lr_decay_style))
 
         return self.min_lr + coeff * delta_lr
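The two renamed schedules above can be sanity-checked outside Megatron. The snippet below is a minimal standalone sketch of the get_wd / get_lr math for the 'cosine' styles; the hyperparameter values are illustrative assumptions, not values taken from this commit.

import math

# Illustrative values (assumptions, not part of this change).
MAX_LR, MIN_LR = 3.0e-4, 3.0e-5
LR_WARMUP_STEPS, LR_DECAY_STEPS = 100, 1000
START_WD, END_WD, WD_INCR_STEPS = 0.01, 0.1, 1000

def lr_at(num_steps):
    """Mirror of get_lr() with lr_decay_style == 'cosine'."""
    if LR_WARMUP_STEPS > 0 and num_steps <= LR_WARMUP_STEPS:
        return MAX_LR * float(num_steps) / float(LR_WARMUP_STEPS)   # linear warmup
    if num_steps > LR_DECAY_STEPS:
        return MIN_LR
    decay_ratio = float(num_steps - LR_WARMUP_STEPS) / float(LR_DECAY_STEPS - LR_WARMUP_STEPS)
    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)           # goes 1 -> 0 over the decay window
    return MIN_LR + coeff * (MAX_LR - MIN_LR)

def wd_at(num_steps):
    """Mirror of get_wd() with wd_incr_style == 'cosine'."""
    if num_steps > WD_INCR_STEPS:
        return END_WD
    incr_ratio = float(num_steps) / float(WD_INCR_STEPS)
    coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)      # goes 0 -> 1 over the incr window
    return START_WD + coeff * (END_WD - START_WD)

# lr ramps 0 -> MAX_LR during warmup, then decays to MIN_LR; wd rises START_WD -> END_WD.
for step in (0, 100, 550, 1000):
    print(step, lr_at(step), wd_at(step))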
@@ -135,11 +136,15 @@ class OptimizerParamScheduler(object):
     def state_dict(self):
         state_dict = {
             'max_lr': self.max_lr,
-            'warmup_steps': self.warmup_steps,
+            'lr_warmup_steps': self.lr_warmup_steps,
             'num_steps': self.num_steps,
-            'decay_style': self.decay_style,
-            'decay_steps': self.decay_steps,
-            'min_lr': self.min_lr
+            'lr_decay_style': self.lr_decay_style,
+            'lr_decay_steps': self.lr_decay_steps,
+            'min_lr': self.min_lr,
+            'start_wd': self.start_wd,
+            'end_wd': self.end_wd,
+            'wd_incr_style': self.wd_incr_style,
+            'wd_incr_steps': self.wd_incr_steps
         }
         return state_dict
@@ -153,7 +158,7 @@ class OptimizerParamScheduler(object):
         if not self.use_checkpoint_opt_param_scheduler:
             assert cls_value == sd_value, \
-                f'AnnealingLR: class input value {cls_value} and checkpoint' \
+                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \
                 f'value {sd_value} for {name} do not match'
         print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                   name))
@@ -174,24 +179,56 @@ class OptimizerParamScheduler(object):
         if 'warmup_iter' in sd:
             warmup_steps_ = sd['warmup_iter']
-        else:
+        elif 'warmup_steps' in sd:
             warmup_steps_ = sd['warmup_steps']
-        self.warmup_steps = self._check_and_set(self.warmup_steps,
-                                                warmup_steps_,
+        else:
+            lr_warmup_steps_ = sd['lr_warmup_steps']
+        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
+                                                   lr_warmup_steps_,
                                                 'warmup iterations')
 
         if 'end_iter' in sd:
-            decay_steps_ = sd['end_iter']
+            lr_decay_steps_ = sd['end_iter']
+        elif 'decay_steps' in sd:
+            lr_decay_steps_ = sd['decay_steps']
         else:
-            decay_steps_ = sd['decay_steps']
-        self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
+            lr_decay_steps_ = sd['lr_decay_steps']
+        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_,
                                                'total number of iterations')
-        self.decay_style = self._check_and_set(self.decay_style,
-                                               sd['decay_style'],
-                                               'decay style')
+
+        if 'decay_style' in sd:
+            lr_decay_style_ = sd['decay_style']
+        else:
+            lr_decay_style_ = sd['lr_decay_style']
+        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
+                                                  lr_decay_style_,
+                                                  'learning rate decay style')
 
         if 'num_iters' in sd:
             num_steps = sd['num_iters']
         else:
             num_steps = sd['num_steps']
         self.step(increment=num_steps)
+
+        if 'start_wd' in sd:
+            self.start_wd = self._check_and_set(self.start_wd,
+                                                sd['start_wd'],
+                                                "start weight decay")
+            self.end_wd = self._check_and_set(self.end_wd,
+                                              sd['end_wd'],
+                                              "end weight decay")
+            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
+                                                     sd['wd_incr_steps'],
+                                                     "total number of weight decay iterations")
+            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
+                                                     sd['wd_incr_style'],
+                                                     "weight decay incr style")
@@ -312,11 +312,12 @@ def get_optimizer_param_scheduler(optimizer):
     if args.train_iters:
         if args.lr_decay_iters is None:
             args.lr_decay_iters = args.train_iters
-        decay_steps = args.lr_decay_iters * args.global_batch_size
+        lr_decay_steps = args.lr_decay_iters * args.global_batch_size
+        wd_incr_steps = args.train_iters * args.global_batch_size
         if args.lr_warmup_fraction is not None:
-            warmup_steps = args.lr_warmup_fraction * decay_steps
+            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
         else:
-            warmup_steps = args.lr_warmup_iters * args.global_batch_size
+            lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
     # Sample-based training.
     elif args.train_samples:
         # We need to set training iters for later use. Technically
@@ -325,11 +326,12 @@ def get_optimizer_param_scheduler(optimizer):
         update_train_iters(args)
         if args.lr_decay_samples is None:
             args.lr_decay_samples = args.train_samples
-        decay_steps = args.lr_decay_samples
+        lr_decay_steps = args.lr_decay_samples
+        wd_incr_steps = args.train_samples
         if args.lr_warmup_fraction is not None:
-            warmup_steps = args.lr_warmup_fraction * decay_steps
+            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
         else:
-            warmup_steps = args.lr_warmup_samples
+            lr_warmup_steps = args.lr_warmup_samples
     else:
         raise Exception(
             'either train-iters or train-samples should be provided.')
@@ -338,11 +340,12 @@ def get_optimizer_param_scheduler(optimizer):
         optimizer,
         max_lr=args.lr,
         min_lr=args.min_lr,
-        warmup_steps=warmup_steps,
-        decay_steps=decay_steps,
-        decay_style=args.lr_decay_style,
+        lr_warmup_steps=lr_warmup_steps,
+        lr_decay_steps=lr_decay_steps,
+        lr_decay_style=args.lr_decay_style,
         start_wd=args.start_weight_decay,
         end_wd=args.end_weight_decay,
+        wd_incr_steps=wd_incr_steps,
         wd_incr_style=args.weight_decay_incr_style,
         use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
         override_opt_param_scheduler=args.override_opt_param_scheduler)
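For code outside get_optimizer_param_scheduler, the constructor is now called with the renamed keyword arguments. The snippet below is a hypothetical usage sketch: the dummy parameter, optimizer, hyperparameter values, and the module path in the import are assumptions for illustration, not part of this commit.

import torch
# Module path assumed; adjust to wherever OptimizerParamScheduler lives in your checkout.
from megatron.optimizer_param_scheduler import OptimizerParamScheduler

# Dummy parameter and optimizer purely for illustration.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=0.0, weight_decay=0.0)

scheduler = OptimizerParamScheduler(
    optimizer,
    max_lr=3.0e-4,
    min_lr=3.0e-5,
    lr_warmup_steps=100,       # was: warmup_steps
    lr_decay_steps=1000,       # was: decay_steps
    lr_decay_style='cosine',   # was: decay_style
    start_wd=0.01,
    end_wd=0.1,
    wd_incr_steps=1000,        # new argument introduced by this change
    wd_incr_style='cosine',
    use_checkpoint_opt_param_scheduler=True,
    override_opt_param_scheduler=False)

# step(increment=...) advances num_steps; get_lr()/get_wd() return the current schedule values.
scheduler.step(increment=10)
print(scheduler.get_lr(), scheduler.get_wd())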