Commit 04ecc834 authored by Vijay Korthikanti

changing class name AnnealingLR to OptimizerParamScheduler

parent 53931b8b
......@@ -577,13 +577,13 @@ def _add_learning_rate_args(parser):
group.add_argument('--min-lr', type=float, default=0.0,
help='Minimum value for learning rate. The scheduler '
'clips values below this threshold.')
group.add_argument('--override-lr-scheduler', action='store_true',
group.add_argument('--override-opt_param-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate, '
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style) from input '
'arguments and ignore values from checkpoints. Note '
'that all the above values will be reset.')
group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style '
......
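The two renamed flags are opposites: --override-opt_param-scheduler makes the command-line values win and ignores whatever the checkpoint stored, while --use-checkpoint-opt_param-scheduler takes the scheduler state from the checkpoint. The diff does not show the surrounding parser setup, so the snippet below is only a hedged sketch of equivalent wiring; the flag names come from the diff, the parser plumbing and the explicit conflict check are assumptions.

import argparse

# Hedged sketch: a stand-alone parser carrying only the two renamed flags.
# The conflict check mirrors the assertion inside OptimizerParamScheduler;
# whether Megatron also enforces it at parse time is an assumption.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='learning rate')
group.add_argument('--override-opt_param-scheduler', action='store_true',
                   help='Reset scheduler values (lr, warmup, min lr, decay) '
                        'from the command line and ignore the checkpoint.')
group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
                   help='Take the scheduler values from the checkpoint instead.')

args = parser.parse_args(['--override-opt_param-scheduler'])
# argparse maps the dashes and the embedded underscore to one attribute name:
assert args.override_opt_param_scheduler is True
assert not (args.override_opt_param_scheduler and
            args.use_checkpoint_opt_param_scheduler), \
    'both override and use-checkpoint are set.'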
......@@ -167,7 +167,7 @@ def get_rng_state():
return rng_state_list
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
"""Save a model checkpoint."""
args = get_args()
......@@ -198,8 +198,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if not args.no_save_optim:
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None:
state_dict['lr_scheduler'] = lr_scheduler.state_dict()
if opt_param_scheduler is not None:
state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
# RNG states.
if not args.no_save_rng:
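With this rename, new checkpoints store the scheduler state under the key 'opt_param_scheduler' rather than 'lr_scheduler'. A rough sketch of the resulting layout; only the 'optimizer' and 'opt_param_scheduler' keys are taken from the diff, the remaining keys and placeholder values are illustrative assumptions:

# Illustrative checkpoint layout after this commit (not a verbatim dump).
state_dict = {
    'model': {},                # model weights, saved elsewhere in this function
    'optimizer': {},            # optimizer.state_dict(), unless --no-save-optim
    'opt_param_scheduler': {},  # was stored under 'lr_scheduler' before this commit
    # RNG state is appended below unless --no-save-rng is given.
}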
......@@ -295,7 +295,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
......@@ -394,8 +394,11 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
if opt_param_scheduler is not None:
if 'lr_scheduler' in state_dict: # backward compatibility
opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
else:
opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
......
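Because older checkpoints still carry the scheduler state under 'lr_scheduler', the load path above falls back to the old key. A minimal stand-alone sketch of that fallback; the helper name is hypothetical, only the key names and the load_state_dict call come from the diff:

def load_scheduler_state(opt_param_scheduler, state_dict):
    """Hypothetical helper mirroring the backward-compatible lookup above."""
    if opt_param_scheduler is None:
        return
    if 'lr_scheduler' in state_dict:
        # Checkpoint written before the rename: fall back to the old key.
        opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
    else:
        # Checkpoint written by the renamed code path.
        opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])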
......@@ -19,14 +19,14 @@ import math
from megatron import print_rank_0
class AnnealingLR(object):
class OptimizerParamScheduler(object):
"""Anneals the learning rate."""
def __init__(self, optimizer, max_lr, min_lr,
warmup_steps, decay_steps, decay_style,
start_wd, end_wd, wd_incr_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
use_checkpoint_opt_param_scheduler=True,
override_opt_param_scheduler=False):
# Class values.
self.optimizer = optimizer
......@@ -51,10 +51,10 @@ class AnnealingLR(object):
self.wd_incr_style = wd_incr_style
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
self.override_opt_param_scheduler = override_opt_param_scheduler
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
if self.override_opt_param_scheduler:
assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
'use-checkpoint are set.'
# Set the learning rate
......@@ -147,11 +147,11 @@ class AnnealingLR(object):
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
if self.override_lr_scheduler:
if self.override_opt_param_scheduler:
print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
return cls_value
if not self.use_checkpoint_lr_scheduler:
if not self.use_checkpoint_opt_param_scheduler:
assert cls_value == sd_value, \
f'OptimizerParamScheduler: class input value {cls_value} and checkpoint ' \
f'value {sd_value} for {name} do not match'
......
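_check_and_set is where the two flags take effect when a checkpoint is loaded: override returns the class (command-line) value, use-checkpoint returns the checkpoint value, and if neither is set the two must agree. A self-contained sketch of that decision, with print_rank_0 replaced by print so it runs outside Megatron:

def check_and_set(cls_value, sd_value, name,
                  override_opt_param_scheduler=False,
                  use_checkpoint_opt_param_scheduler=True):
    """Stand-alone sketch of OptimizerParamScheduler._check_and_set."""
    if override_opt_param_scheduler:
        # Command-line value wins; the checkpoint value is ignored.
        print(' > overriding {} value to {}'.format(name, cls_value))
        return cls_value
    if not use_checkpoint_opt_param_scheduler:
        # Neither flag set: the two sources must agree exactly.
        assert cls_value == sd_value, \
            'class input value {} and checkpoint value {} ' \
            'for {} do not match'.format(cls_value, sd_value, name)
    # Default behaviour: trust the value stored in the checkpoint.
    return sd_value

# Example: with the defaults, the checkpointed learning rate is kept.
assert check_and_set(3e-4, 1e-4, 'learning rate') == 1e-4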
......@@ -43,7 +43,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
......@@ -118,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider,
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
model_type)
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
......@@ -149,7 +149,7 @@ def pretrain(train_valid_test_dataset_provider,
iteration = 0
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func)
print_datetime('after training is done')
......@@ -162,7 +162,7 @@ def pretrain(train_valid_test_dataset_provider,
False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
if args.do_test:
# Run on test data.
......@@ -304,7 +304,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
return model
def get_learning_rate_scheduler(optimizer):
def get_optimizer_param_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
......@@ -334,7 +334,7 @@ def get_learning_rate_scheduler(optimizer):
raise Exception(
'either train-iters or train-samples should be provided.')
lr_scheduler = AnnealingLR(
opt_param_scheduler = OptimizerParamScheduler(
optimizer,
max_lr=args.lr,
min_lr=args.min_lr,
......@@ -344,10 +344,10 @@ def get_learning_rate_scheduler(optimizer):
start_wd=args.start_weight_decay,
end_wd=args.end_weight_decay,
wd_incr_style=args.weight_decay_incr_style,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
override_opt_param_scheduler=args.override_opt_param_scheduler)
return lr_scheduler
return opt_param_scheduler
def setup_model_and_optimizer(model_provider_func,
......@@ -365,7 +365,7 @@ def setup_model_and_optimizer(model_provider_func,
optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
lr_scheduler = get_learning_rate_scheduler(optimizer)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None:
timers = get_timers()
......@@ -373,7 +373,7 @@ def setup_model_and_optimizer(model_provider_func,
# max time.
torch.distributed.barrier()
timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])
......@@ -392,11 +392,11 @@ def setup_model_and_optimizer(model_provider_func,
if args.fp16:
optimizer.reload_model_params()
return model, optimizer, lr_scheduler
return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
model, optimizer, opt_param_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
......@@ -472,7 +472,7 @@ def train_step(forward_step_func, data_iterator,
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
lr_scheduler.step(increment=increment)
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
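The increment passed to opt_param_scheduler.step() is the number of samples consumed by the whole data-parallel group in this optimizer step, so a sample-based schedule advances by one global batch per iteration. A worked example with assumed values:

# Assumed example values, not taken from the diff.
num_microbatches = 8        # get_num_microbatches()
micro_batch_size = 4        # args.micro_batch_size
data_parallel_size = 16     # args.data_parallel_size

# Samples consumed by one optimizer step across all data-parallel replicas.
increment = num_microbatches * micro_batch_size * data_parallel_size
assert increment == 512     # one global batch

# opt_param_scheduler.step(increment=increment) would then advance the
# learning-rate and weight-decay schedule by 512 samples.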
......@@ -662,19 +662,19 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
return report_memory_flag
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
torch.distributed.barrier()
timers('save-checkpoint').start()
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
timers('save-checkpoint').stop()
timers.log(['save-checkpoint'])
def train(forward_step_func, model, optimizer, lr_scheduler,
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func):
"""Train the model function."""
......@@ -704,7 +704,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator,
model,
optimizer,
lr_scheduler)
opt_param_scheduler)
iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
......@@ -725,7 +725,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
......@@ -742,14 +742,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
signal_handler = get_signal_handler()
if any(signal_handler.signals_received()):
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
print_datetime('exiting program after receiving SIGTERM.')
sys.exit()
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
saved_checkpoint = True
# Exiting based on duration
......@@ -763,7 +763,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if done:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit()
......@@ -771,7 +771,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit()
......
......@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration):
def check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler):
optimizer, opt_param_scheduler):
"""Check for autoresume signal and exit if it is received."""
from megatron.checkpointing import save_checkpoint
......@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model,
torch.distributed.barrier()
if autoresume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
print_rank_0(">>> autoresume termination request found!")
if torch.distributed.get_rank() == 0:
autoresume.request_resume()
......
......@@ -154,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
return train_dataloader, valid_dataloader
def _train(model, optimizer, lr_scheduler, forward_step,
def _train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback):
"""Train the model."""
args = get_args()
......@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
start_iteration = 0
# Train for one step.
out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1
......@@ -215,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler)
optimizer, opt_param_scheduler)
# Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
saved_checkpoint = True
# Evaluation
......@@ -234,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
print_rank_0('exiting program at iteration {}'.format(iteration))
sys.exit()
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
......@@ -279,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Build model, optimizer and learning rate scheduler.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
timers('model and optimizer').stop()
# If pretrained checkpoint is provided and we have not trained for
......@@ -307,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Finetune the model.
if args.epochs > 0:
_train(model, optimizer, lr_scheduler, forward_step,
_train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback)
# Or just evaluate.
else:
......
......@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
def _train(
model,
optimizer,
lr_scheduler,
opt_param_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
......@@ -179,7 +179,7 @@ def _train(
# Train for one step.
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
forward_step, batch, model, optimizer, lr_scheduler
forward_step, batch, model, optimizer, opt_param_scheduler
)
iteration += 1
......@@ -206,7 +206,7 @@ def _train(
iteration % args.adlr_autoresume_interval == 0
):
check_adlr_autoresume_termination(
iteration, model, optimizer, lr_scheduler
iteration, model, optimizer, opt_param_scheduler
)
# Checkpointing
......@@ -215,7 +215,7 @@ def _train(
and args.save_interval
and iteration % args.save_interval == 0
):
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
......@@ -231,7 +231,7 @@ def _train(
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
......@@ -266,7 +266,7 @@ def finetune(
# Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
timers("model and optimizer").stop()
# If pretrained checkpoint is provided and we have not trained for
......@@ -300,7 +300,7 @@ def finetune(
_train(
model,
optimizer,
lr_scheduler,
opt_param_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
......