"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f90a5139a244d440b9a85b419a52cc0155f32f98"
Commit 04ecc834 authored by Vijay Korthikanti

changing class name AnnealingLR to OptimizerParamScheduler

parent 53931b8b
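
This commit is a mechanical rename: the AnnealingLR class becomes OptimizerParamScheduler, the lr_scheduler variables and checkpoint key become opt_param_scheduler, and the --override-lr-scheduler / --use-checkpoint-lr-scheduler flags become --override-opt_param-scheduler / --use-checkpoint-opt_param-scheduler. As a rough usage sketch (not part of the commit; the numeric values below are placeholders, only the keyword names come from the diff), a caller would now construct the renamed scheduler like this:

    # Hypothetical sketch of the renamed scheduler; values are illustrative only.
    from megatron.optimizer_param_scheduler import OptimizerParamScheduler

    opt_param_scheduler = OptimizerParamScheduler(
        optimizer,                     # optimizer built by get_megatron_optimizer()
        max_lr=1.5e-4,
        min_lr=1.0e-5,
        warmup_steps=2000,
        decay_steps=300000,
        decay_style='cosine',
        start_wd=0.01,
        end_wd=0.01,
        wd_incr_style='constant',
        use_checkpoint_opt_param_scheduler=True,   # was use_checkpoint_lr_scheduler
        override_opt_param_scheduler=False)        # was override_lr_scheduler

    # Advanced by consumed samples after each successful step, as in train_step():
    #     opt_param_scheduler.step(increment=increment)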
@@ -577,13 +577,13 @@ def _add_learning_rate_args(parser):
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minumum value for learning rate. The scheduler'
                        'clip values below this threshold.')
-    group.add_argument('--override-lr-scheduler', action='store_true',
+    group.add_argument('--override-opt_param-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
                        'number of iterations, and decay style from input '
                        'arguments and ignore values from checkpoints. Note'
                        'that all the above values will be reset.')
-    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
+    group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
                        help='Use checkpoint to set the values of the scheduler '
                        '(learning rate, warmup iterations, minimum learning '
                        'rate, maximum number of iterations, and decay style '
...
@@ -167,7 +167,7 @@ def get_rng_state():
     return rng_state_list


-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     """Save a model checkpoint."""
     args = get_args()
@@ -198,8 +198,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         if not args.no_save_optim:
             if optimizer is not None:
                 state_dict['optimizer'] = optimizer.state_dict()
-            if lr_scheduler is not None:
-                state_dict['lr_scheduler'] = lr_scheduler.state_dict()
+            if opt_param_scheduler is not None:
+                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()

         # RNG states.
         if not args.no_save_rng:
@@ -295,7 +295,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
         print_rank_0(" succesfully fixed query-key-values ordering for"
                      " checkpoint version {}".format(checkpoint_version))


-def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
+def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     """Load a model checkpoint and return the iteration.
     strict (bool): whether to strictly enforce that the keys in
         :attr:`state_dict` of the checkpoint match the names of
@@ -394,8 +394,11 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         try:
             if optimizer is not None:
                 optimizer.load_state_dict(state_dict['optimizer'])
-            if lr_scheduler is not None:
-                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+            if opt_param_scheduler is not None:
+                if 'lr_scheduler' in state_dict: # backward compatbility
+                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
+                else:
+                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
         except KeyError:
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                          'Specify --no-load-optim or --finetune to prevent '
...
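
For context, the new load_checkpoint logic keeps pre-rename checkpoints loadable: scheduler state saved before this commit lives under the 'lr_scheduler' key, new checkpoints use 'opt_param_scheduler', and the loader checks for the old key first. A minimal sketch of the same fallback, assuming a checkpoint state_dict already read from disk (the torch.load call and variable names are illustrative):

    # Hypothetical sketch of the backward-compatible key fallback shown above.
    import torch

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    sched_state = state_dict.get('lr_scheduler',                 # pre-rename key
                                 state_dict.get('opt_param_scheduler'))
    if sched_state is not None:
        opt_param_scheduler.load_state_dict(sched_state)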
@@ -19,14 +19,14 @@ import math

 from megatron import print_rank_0

-class AnnealingLR(object):
+class OptimizerParamScheduler(object):
     """Anneals the learning rate."""

     def __init__(self, optimizer, max_lr, min_lr,
                  warmup_steps, decay_steps, decay_style,
                  start_wd, end_wd, wd_incr_style,
-                 use_checkpoint_lr_scheduler=True,
-                 override_lr_scheduler=False):
+                 use_checkpoint_opt_param_scheduler=True,
+                 override_opt_param_scheduler=False):

         # Class values.
         self.optimizer = optimizer
@@ -51,10 +51,10 @@ class AnnealingLR(object):
         self.wd_incr_style = wd_incr_style

-        self.override_lr_scheduler = override_lr_scheduler
-        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
-        if self.override_lr_scheduler:
-            assert not self.use_checkpoint_lr_scheduler, 'both override and '\
+        self.override_opt_param_scheduler = override_opt_param_scheduler
+        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
+        if self.override_opt_param_scheduler:
+            assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
                 'use-checkpoint are set.'

         # Set the learning rate
@@ -147,11 +147,11 @@ class AnnealingLR(object):
     def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""
-        if self.override_lr_scheduler:
+        if self.override_opt_param_scheduler:
             print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
             return cls_value

-        if not self.use_checkpoint_lr_scheduler:
+        if not self.use_checkpoint_opt_param_scheduler:
             assert cls_value == sd_value, \
                 f'AnnealingLR: class input value {cls_value} and checkpoint' \
                 f'value {sd_value} for {name} do not match'
...
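
The scheduler hunks above keep the existing three-way behaviour under the new names: with --override-opt_param-scheduler the constructor arguments win, with --use-checkpoint-opt_param-scheduler the checkpoint values win, and with neither flag the two sources must agree. A standalone sketch of that rule (a free function that mirrors _check_and_set, not the actual Megatron method):

    # Hypothetical sketch of the override/use-checkpoint precedence rule.
    def check_and_set(cls_value, sd_value, name, override=False, use_checkpoint=False):
        if override:
            return cls_value                      # constructor argument wins
        if not use_checkpoint:
            assert cls_value == sd_value, \
                f'class input value {cls_value} and checkpoint value ' \
                f'{sd_value} for {name} do not match'
        return sd_value                           # checkpoint value wins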
@@ -43,7 +43,7 @@ from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
-from megatron.learning_rates import AnnealingLR
+from megatron.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
@@ -118,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider,
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
                                                                model_type)
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
@@ -149,7 +149,7 @@ def pretrain(train_valid_test_dataset_provider,
     iteration = 0
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
-                          model, optimizer, lr_scheduler,
+                          model, optimizer, opt_param_scheduler,
                           train_data_iterator, valid_data_iterator,
                           process_non_loss_data_func)
         print_datetime('after training is done')
@@ -162,7 +162,7 @@ def pretrain(train_valid_test_dataset_provider,
                                    False)

     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

     if args.do_test:
         # Run on test data.
@@ -304,7 +304,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     return model


-def get_learning_rate_scheduler(optimizer):
+def get_optimizer_param_scheduler(optimizer):
     """Build the learning rate scheduler."""
     args = get_args()
@@ -334,7 +334,7 @@ def get_learning_rate_scheduler(optimizer):
         raise Exception(
             'either train-iters or train-samples should be provided.')

-    lr_scheduler = AnnealingLR(
+    opt_param_scheduler = OptimizerParamScheduler(
         optimizer,
         max_lr=args.lr,
         min_lr=args.min_lr,
@@ -344,10 +344,10 @@ def get_learning_rate_scheduler(optimizer):
         start_wd=args.start_weight_decay,
         end_wd=args.end_weight_decay,
         wd_incr_style=args.weight_decay_incr_style,
-        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
-        override_lr_scheduler=args.override_lr_scheduler)
+        use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
+        override_opt_param_scheduler=args.override_opt_param_scheduler)

-    return lr_scheduler
+    return opt_param_scheduler


 def setup_model_and_optimizer(model_provider_func,
@@ -365,7 +365,7 @@ def setup_model_and_optimizer(model_provider_func,
     optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
                                        scale_lr_cond, lr_mult)
-    lr_scheduler = get_learning_rate_scheduler(optimizer)
+    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

     if args.load is not None:
         timers = get_timers()
@@ -373,7 +373,7 @@ def setup_model_and_optimizer(model_provider_func,
         # max time.
         torch.distributed.barrier()
         timers('load-checkpoint').start()
-        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
@@ -392,11 +392,11 @@ def setup_model_and_optimizer(model_provider_func,
     if args.fp16:
         optimizer.reload_model_params()

-    return model, optimizer, lr_scheduler
+    return model, optimizer, opt_param_scheduler


 def train_step(forward_step_func, data_iterator,
-               model, optimizer, lr_scheduler):
+               model, optimizer, opt_param_scheduler):
     """Single training step."""
     args = get_args()
     timers = get_timers()
@@ -472,7 +472,7 @@ def train_step(forward_step_func, data_iterator,
         increment = get_num_microbatches() * \
                     args.micro_batch_size * \
                     args.data_parallel_size
-        lr_scheduler.step(increment=increment)
+        opt_param_scheduler.step(increment=increment)
         skipped_iter = 0
     else:
         skipped_iter = 1
@@ -662,19 +662,19 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     return report_memory_flag


-def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
     timers = get_timers()
     # Extra barrier is added to make sure
     # all ranks report the max time.
     torch.distributed.barrier()
     timers('save-checkpoint').start()
-    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
     torch.distributed.barrier()
     timers('save-checkpoint').stop()
     timers.log(['save-checkpoint'])


-def train(forward_step_func, model, optimizer, lr_scheduler,
+def train(forward_step_func, model, optimizer, opt_param_scheduler,
           train_data_iterator, valid_data_iterator,
           process_non_loss_data_func):
     """Train the model function."""
@@ -704,7 +704,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        train_data_iterator,
                        model,
                        optimizer,
-                       lr_scheduler)
+                       opt_param_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \
@@ -725,7 +725,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.adlr_autoresume and \
            (iteration % args.adlr_autoresume_interval == 0):
             check_adlr_autoresume_termination(iteration, model, optimizer,
-                                              lr_scheduler)
+                                              opt_param_scheduler)

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \
@@ -742,14 +742,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             signal_handler = get_signal_handler()
             if any(signal_handler.signals_received()):
                 save_checkpoint_and_time(iteration, model, optimizer,
-                                         lr_scheduler)
+                                         opt_param_scheduler)
                 print_datetime('exiting program after receiving SIGTERM.')
                 sys.exit()

         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
             save_checkpoint_and_time(iteration, model, optimizer,
-                                     lr_scheduler)
+                                     opt_param_scheduler)
             saved_checkpoint = True

         # Exiting based on duration
@@ -763,7 +763,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             if done:
                 if not saved_checkpoint:
                     save_checkpoint_and_time(iteration, model, optimizer,
-                                             lr_scheduler)
+                                             opt_param_scheduler)
                 print_datetime('exiting program after {} minutes'.format(train_time))
                 sys.exit()
@@ -771,7 +771,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.exit_interval and iteration % args.exit_interval == 0:
             if not saved_checkpoint:
                 save_checkpoint_and_time(iteration, model, optimizer,
-                                         lr_scheduler)
+                                         opt_param_scheduler)
             torch.distributed.barrier()
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
...
@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration):


 def check_adlr_autoresume_termination(iteration, model,
-                                      optimizer, lr_scheduler):
+                                      optimizer, opt_param_scheduler):
     """Check for autoresume signal and exit if it is received."""
     from megatron.checkpointing import save_checkpoint
@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model,
     torch.distributed.barrier()
     if autoresume.termination_requested():
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
         print_rank_0(">>> autoresume termination request found!")
         if torch.distributed.get_rank() == 0:
             autoresume.request_resume()
...
@@ -154,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
     return train_dataloader, valid_dataloader


-def _train(model, optimizer, lr_scheduler, forward_step,
+def _train(model, optimizer, opt_param_scheduler, forward_step,
            train_dataloader, valid_dataloader, end_of_epoch_callback):
     """Train the model."""
     args = get_args()
@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                 start_iteration = 0

             # Train for one step.
-            out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
+            out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
             iteration += 1
@@ -215,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             if args.adlr_autoresume and \
                (iteration % args.adlr_autoresume_interval == 0):
                 check_adlr_autoresume_termination(iteration, model,
-                                                  optimizer, lr_scheduler)
+                                                  optimizer, opt_param_scheduler)

             # Checkpointing
             saved_checkpoint = False
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                 saved_checkpoint = True

             # Evaluation
@@ -234,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             # Exiting based on iterations
             if args.exit_interval and iteration % args.exit_interval == 0:
                 if not saved_checkpoint:
-                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                 torch.distributed.barrier()
                 print_rank_0('exiting program at iteration {}'.format(iteration))
                 sys.exit()

         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:
@@ -279,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     # Build model, optimizer and learning rate scheduler.
     timers('model and optimizer').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
     timers('model and optimizer').stop()

     # If pretrained checkpoint is provided and we have not trained for
@@ -307,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     # Finetune the model.
     if args.epochs > 0:
-        _train(model, optimizer, lr_scheduler, forward_step,
+        _train(model, optimizer, opt_param_scheduler, forward_step,
                train_dataloader, valid_dataloader, end_of_epoch_callback)
     # Or just evaluate.
     else:
...
@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
 def _train(
     model,
     optimizer,
-    lr_scheduler,
+    opt_param_scheduler,
     forward_step,
     train_dataloader,
     valid_dataloader,
@@ -179,7 +179,7 @@ def _train(
             # Train for one step.
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
-                forward_step, batch, model, optimizer, lr_scheduler
+                forward_step, batch, model, optimizer, opt_param_scheduler
             )
             iteration += 1
@@ -206,7 +206,7 @@ def _train(
                 iteration % args.adlr_autoresume_interval == 0
             ):
                 check_adlr_autoresume_termination(
-                    iteration, model, optimizer, lr_scheduler
+                    iteration, model, optimizer, opt_param_scheduler
                 )

             # Checkpointing
@@ -215,7 +215,7 @@ def _train(
                 and args.save_interval
                 and iteration % args.save_interval == 0
             ):
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -231,7 +231,7 @@ def _train(
         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:
@@ -266,7 +266,7 @@ def finetune(
     # Build model, optimizer and learning rate scheduler.
     timers("model and optimizer").start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
     timers("model and optimizer").stop()

     # If pretrained checkpoint is provided and we have not trained for
@@ -300,7 +300,7 @@ def finetune(
         _train(
             model,
             optimizer,
-            lr_scheduler,
+            opt_param_scheduler,
             forward_step,
             train_dataloader,
             valid_dataloader,
...