Commit b1a6d73b authored by zihanl's avatar zihanl

fix training.py

parent 6fd0b406
@@ -141,7 +141,6 @@ def pretrain(train_valid_test_dataset_provider,
print_rank_0('training ...')
iteration = 0
# if not args.run_dialog:
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
@@ -163,7 +162,7 @@ def pretrain(train_valid_test_dataset_provider,
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
0, True)
def update_train_iters(args):
# For iteration-based training, we don't need to do anything
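The hunk above keeps the iteration-based early return in update_train_iters. For context, here is a minimal sketch (not the file's actual body) of how the sample-based fallback is commonly handled, assuming a constant global batch size and no batch-size ramp-up:

```python
# Hedged sketch: the usual conversion from a sample budget to an iteration
# count when args.train_iters is not set. The real update_train_iters in
# training.py may also handle batch-size ramp-up, which is omitted here.
def update_train_iters(args):
    # For iteration-based training, we don't need to do anything.
    if args.train_iters:
        return
    # Sample-based training with a constant global batch size:
    # each iteration consumes exactly global_batch_size samples.
    args.train_iters = args.train_samples // args.global_batch_size
```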
@@ -355,8 +354,6 @@ def setup_model_and_optimizer(model_provider_func, model_type):
torch.distributed.barrier()
timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
# need to set train_samples to None
args.train_samples = None
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])
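The checkpoint load in this hunk is bracketed by torch.distributed.barrier() calls so that the logged 'load-checkpoint' time reflects the slowest rank rather than whichever rank finishes first. A stand-alone sketch of that pattern (using plain time.time() instead of Megatron's Timers, and assuming an already-initialized process group):

```python
import time
import torch

def timed_collective(step_fn, name):
    """Time a step that every rank must finish together (sketch only)."""
    # Barrier before starting the clock: no rank starts timing while
    # others are still busy with earlier work.
    torch.distributed.barrier()
    start = time.time()
    result = step_fn()
    # Barrier before stopping the clock: the measurement covers the
    # slowest rank, which is what matters for wall-clock progress.
    torch.distributed.barrier()
    elapsed = time.time() - start
    if torch.distributed.get_rank() == 0:
        print('{}: {:.2f} s'.format(name, elapsed))
    return result

# Hypothetical usage mirroring the hunk above:
# args.iteration = timed_collective(
#     lambda: load_checkpoint(model, optimizer, lr_scheduler), 'load-checkpoint')
```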
@@ -662,9 +659,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Iterations.
iteration = args.iteration
# if not args.run_dialog:
timers('interval-time').start()
print_datetime('before the start of training step')
report_memory_flag = True
while iteration < args.train_iters:
@@ -860,7 +855,7 @@ def build_train_valid_test_data_iterators(
else:
train_samples = args.train_iters * args.global_batch_size
eval_iters = (args.train_iters // args.eval_interval + 1) * \
args.eval_iters
args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_samples,
eval_iters * args.global_batch_size,
......
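For reference, the eval/test sample counts in the last hunk follow directly from the iteration counts. A small worked example with purely illustrative values (none of these numbers come from the commit):

```python
# Hypothetical values, only to make the arithmetic in
# build_train_valid_test_data_iterators concrete.
train_iters = 1000
global_batch_size = 32
eval_interval = 100
eval_iters_arg = 10  # stands in for args.eval_iters

train_samples = train_iters * global_batch_size                    # 32000
eval_iters = (train_iters // eval_interval + 1) * eval_iters_arg   # (10 + 1) * 10 = 110
test_iters = eval_iters_arg                                        # 10

train_val_test_num_samples = [train_samples,
                              eval_iters * global_batch_size,      # 3520
                              test_iters * global_batch_size]      # 320
print(train_val_test_num_samples)  # [32000, 3520, 320]
```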