Commit b1a6d73b authored by zihanl's avatar zihanl
Browse files

fix training.py

parent 6fd0b406
......@@ -141,7 +141,6 @@ def pretrain(train_valid_test_dataset_provider,
print_rank_0('training ...')
iteration = 0
# if not args.run_dialog:
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
......@@ -355,8 +354,6 @@ def setup_model_and_optimizer(model_provider_func, model_type):
torch.distributed.barrier()
timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
# need to set train_samples to None
args.train_samples = None
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])
......@@ -662,9 +659,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Iterations.
iteration = args.iteration
# if not args.run_dialog:
timers('interval-time').start()
print_datetime('before the start of training step')
report_memory_flag = True
while iteration < args.train_iters:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment