Commit b1a6d73b authored by zihanl's avatar zihanl

fix training.py

Remove the commented-out run_dialog guards and drop the train_samples
override that was applied after checkpoint loading.

parent 6fd0b406
training.py
@@ -141,7 +141,6 @@ def pretrain(train_valid_test_dataset_provider,
     print_rank_0('training ...')
     iteration = 0
-    # if not args.run_dialog:
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
@@ -355,8 +354,6 @@ def setup_model_and_optimizer(model_provider_func, model_type):
         torch.distributed.barrier()
         timers('load-checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
-        # need to set train_samples to None
-        args.train_samples = None
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
@@ -662,9 +659,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
-    # if not args.run_dialog:
     timers('interval-time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
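
The train_samples removal is arguably the substantive part of this commit. In Megatron-style training scripts, a run is usually sized either by an explicit iteration count (--train-iters) or by a total sample budget (--train-samples), and forcing args.train_samples to None right after load_checkpoint would push a resumed run back onto iteration-based accounting. The snippet below is only an illustrative sketch of that interaction, not code from this repository; resolve_train_iters and its parameters are hypothetical stand-ins.

# Illustrative sketch only -- not code from this repository. It assumes
# Megatron-style semantics where a run is sized either by an explicit
# iteration count or by a total sample budget.

def resolve_train_iters(train_iters, train_samples, global_batch_size):
    """Pick the iteration budget from whichever sizing option is set."""
    if train_samples is not None:
        # Sample-based run: derive the iteration count from the sample budget.
        return train_samples // global_batch_size
    # Iteration-based run: use the explicit iteration count.
    return train_iters

# With the old `args.train_samples = None` override in place, a resumed
# sample-based run would fall through to the iteration-based branch after
# loading a checkpoint; without it, the original sizing mode is preserved.
print(resolve_train_iters(train_iters=None, train_samples=1_000_000,
                          global_batch_size=512))  # -> 1953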