Commit b1a6d73b authored by zihanl

fix training.py

parent 6fd0b406
@@ -141,7 +141,6 @@ def pretrain(train_valid_test_dataset_provider,
     print_rank_0('training ...')
     iteration = 0
-    # if not args.run_dialog:
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
@@ -163,7 +162,7 @@ def pretrain(train_valid_test_dataset_provider,
         evaluate_and_print_results(prefix, forward_step_func,
                                    test_data_iterator, model,
                                    0, True)
 
 def update_train_iters(args):
     # For iteration-based training, we don't need to do anything
@@ -355,8 +354,6 @@ def setup_model_and_optimizer(model_provider_func, model_type):
         torch.distributed.barrier()
         timers('load-checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
-        # need to set train_samples to None
-        args.train_samples = None
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
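
Note on the two deleted lines above: Megatron-style training can be driven either by an iteration budget (train_iters) or a sample budget (train_samples), and forcing args.train_samples = None right after load_checkpoint disabled the sample-based path on resume; this commit removes that override, presumably so sample-based runs resume correctly. As a rough sketch of how update_train_iters (second hunk above) typically bridges the two modes, assuming a constant global batch size and ignoring batch-size ramp-up (illustrative, not this repository's exact code):

    def update_train_iters(args):
        # For iteration-based training, we don't need to do anything.
        if args.train_iters is not None:
            return
        # Sample-based training: derive the iteration count from the
        # sample budget, assuming a constant global batch size.
        args.train_iters = args.train_samples // args.global_batch_size
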
@@ -662,9 +659,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
-    # if not args.run_dialog:
     timers('interval-time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
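
Aside: the timers(...) calls in this hunk and the load-checkpoint hunk use a named-timer registry pattern (timers('interval-time').start(), .stop(), timers.log([...])). A minimal sketch of that interface, assuming plain wall-clock accumulation (the real Megatron Timers class also handles CUDA synchronization, omitted here):

    import time

    class _Timer:
        # A single named timer that accumulates elapsed wall-clock time.
        def __init__(self, name):
            self.name = name
            self.elapsed = 0.0
            self._start = None

        def start(self):
            self._start = time.time()

        def stop(self):
            self.elapsed += time.time() - self._start
            self._start = None

    class Timers:
        # Registry: timers('load-checkpoint') returns (creating if
        # needed) the timer registered under that name.
        def __init__(self):
            self._timers = {}

        def __call__(self, name):
            return self._timers.setdefault(name, _Timer(name))

        def log(self, names):
            for name in names:
                print(f'{name}: {self._timers[name].elapsed * 1000.0:.2f} ms')
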
@@ -860,7 +855,7 @@ def build_train_valid_test_data_iterators(
         else:
             train_samples = args.train_iters * args.global_batch_size
         eval_iters = (args.train_iters // args.eval_interval + 1) * \
                      args.eval_iters
         test_iters = args.eval_iters
         train_val_test_num_samples = [train_samples,
                                       eval_iters * args.global_batch_size,
...
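
For a sense of the numbers in the last hunk: evaluation runs every eval_interval iterations plus once at the end, so the code budgets (train_iters // eval_interval + 1) * eval_iters validation iterations and converts every iteration count into a sample count via global_batch_size. A worked example with purely illustrative values (args.eval_iters is renamed eval_iters_per_run here to avoid shadowing the computed total):

    # Hypothetical configuration, for illustration only.
    train_iters = 500000
    global_batch_size = 1024
    eval_interval = 1000
    eval_iters_per_run = 100

    train_samples = train_iters * global_batch_size        # 512,000,000
    eval_iters = (train_iters // eval_interval + 1) * \
                 eval_iters_per_run                        # 50,100
    test_iters = eval_iters_per_run                        # 100
    train_val_test_num_samples = [
        train_samples,                                     # 512,000,000
        eval_iters * global_batch_size,                    # 51,302,400
        test_iters * global_batch_size,                    # 102,400
    ]
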