Commit 294e81c1 authored by zihanl

update training.py

parent e57a8f74
@@ -138,8 +138,7 @@ def pretrain(train_valid_test_dataset_provider,
     print_rank_0('training ...')
     iteration = 0
-    if not args.run_dialog:
-        # original pre-training for GPT
+    # if not args.run_dialog:
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
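
Note: this commit disables the `--run-dialog` code path throughout training.py, so `pretrain` falls back to iteration-based GPT pre-training only. The real flag definitions live outside this file; the sketch below is a hypothetical reconstruction, inferred purely from the `args.*` attributes used in this diff (names follow argparse's dash-to-underscore convention; defaults are assumptions).

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--run-dialog', action='store_true',
                    help='Use the epoch-based dialog/control training loop '
                         'instead of iteration-based GPT pre-training.')
parser.add_argument('--num-epoch', type=int, default=10,
                    help='Number of epochs for the dialog/control loop.')
parser.add_argument('--train-module', type=str, default='dialog',
                    choices=['dialog', 'control'],
                    help='Which model variant is trained; selects the epoch '
                         'window in which checkpoints are saved.')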
@@ -162,37 +161,37 @@ def pretrain(train_valid_test_dataset_provider,
                                        test_data_iterator, model,
                                        0, True)
-    else:
-        # training for dialog/control model
-        timers('interval-time').start() # start timers('interval-time') here to avoid it from starting multiple times
-        for e in range(args.num_epoch):
-            print_rank_0('> training on epoch %d' % (e+1))
-            if args.do_train and args.train_iters > 0:
-                iteration += train(forward_step_func,
-                                   model, optimizer, lr_scheduler,
-                                   train_data_iterator, valid_data_iterator)
-                print_datetime('after training is done')
-            if args.do_valid:
-                prefix = 'the end of training for val data'
-                evaluate_and_print_results(prefix, forward_step_func,
-                                           valid_data_iterator, model,
-                                           iteration, False)
-            # if args.train_module == "dialog":
-            #     if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
-            #         save_checkpoint(iteration, model, optimizer, lr_scheduler)
-            if args.train_module == "control":
-                if (e+1) >= 5 and (e+1) <= 9 and args.save and iteration != 0:
-                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
-        if args.do_test:
-            # Run on test data.
-            prefix = 'the end of training for test data'
-            evaluate_and_print_results(prefix, forward_step_func,
-                                       test_data_iterator, model,
-                                       0, True)
+    # else:
+    #     # training for dialog/control model
+    #     timers('interval-time').start() # start timers('interval-time') here to avoid it from starting multiple times
+    #     for e in range(args.num_epoch):
+    #         print_rank_0('> training on epoch %d' % (e+1))
+    #         if args.do_train and args.train_iters > 0:
+    #             iteration += train(forward_step_func,
+    #                                model, optimizer, lr_scheduler,
+    #                                train_data_iterator, valid_data_iterator)
+    #             print_datetime('after training is done')
+    #         if args.do_valid:
+    #             prefix = 'the end of training for val data'
+    #             evaluate_and_print_results(prefix, forward_step_func,
+    #                                        valid_data_iterator, model,
+    #                                        iteration, False)
+    #         # if args.train_module == "dialog":
+    #         #     if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
+    #         #         save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    #         if args.train_module == "control":
+    #             if (e+1) >= 5 and (e+1) <= 9 and args.save and iteration != 0:
+    #                 save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    #     if args.do_test:
+    #         # Run on test data.
+    #         prefix = 'the end of training for test data'
+    #         evaluate_and_print_results(prefix, forward_step_func,
+    #                                    test_data_iterator, model,
+    #                                    0, True)
 
 def update_train_iters(args):
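
Note: the loop commented out above trained the dialog/control model for a fixed number of epochs and saved checkpoints only inside a per-module epoch window: epochs 6-15 for "dialog" (that branch was itself already commented out in the old code) and epochs 5-9 for "control". A self-contained sketch of that save predicate, with windows and names taken from the old code; the warm-up/convergence rationale is an assumption.

def should_save(epoch, train_module, iteration, save):
    """Replicates the epoch-window checkpoint policy of the old loop.

    `epoch` is 1-based (the old code tested e+1 with e from range()).
    Saves are presumably skipped before the window (early epochs not yet
    worth keeping) and after it (model assumed converged).
    """
    windows = {'dialog': (6, 15), 'control': (5, 9)}
    lo, hi = windows[train_module]
    return save and iteration != 0 and lo <= epoch <= hi

# Example: with --train-module control, only epochs 5..9 are saved.
assert should_save(5, 'control', 100, save=True)
assert not should_save(4, 'control', 100, save=True)
assert not should_save(10, 'control', 100, save=True)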
@@ -645,7 +644,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
-    if not args.run_dialog:
+    # if not args.run_dialog:
     timers('interval-time').start()
     print_datetime('before the start of training step')
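
Note on the `timers('interval-time')` changes in this hunk and the one above: Megatron-style timers typically assert that a timer is not already running when started, which is why the old epoch loop started 'interval-time' once in `pretrain` and guarded the start inside `train` with `if not args.run_dialog:`. A minimal stand-in illustrating that invariant (this is a sketch, not the actual megatron timers implementation):

import time

class Timer:
    """Minimal stand-in: starting an already-running timer is an error."""
    def __init__(self, name):
        self.name = name
        self.started = False
        self.elapsed_ = 0.0
        self.start_time = 0.0

    def start(self):
        assert not self.started, f'timer {self.name} has already been started'
        self.start_time = time.time()
        self.started = True

    def stop(self):
        assert self.started, f'timer {self.name} is not running'
        self.elapsed_ += time.time() - self.start_time
        self.started = False

# Under the old epoch loop, train() was entered once per epoch; an
# unconditional start() inside train() would trip the assertion on the
# second epoch, hence the (now removed) run_dialog guard.
t = Timer('interval-time')
t.start()
# t.start()  # would raise AssertionError: already started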
@@ -829,32 +828,32 @@ def build_train_valid_test_data_iterators(
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
             args.eval_iters * args.global_batch_size
-    if args.run_dialog:
-        args.consumed_train_samples = 0
-        args.consumed_valid_samples = 0
-        args.iteration = 0
+    # if args.run_dialog:
+    #     args.consumed_train_samples = 0
+    #     args.consumed_valid_samples = 0
+    #     args.iteration = 0
 
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_tensor_model_parallel_rank() == 0:
-        if args.run_dialog:
-            # Build the datasets.
-            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
-            print_rank_0(' > datasets target sizes:')
-            train_size = len(train_ds)
-            valid_size = len(valid_ds)
-            test_size = len(test_ds)
-            print_rank_0('    train:      {}'.format(train_size))
-            print_rank_0('    validation: {}'.format(valid_size))
-            print_rank_0('    test:       {}'.format(test_size))
-            batch_size = args.global_batch_size
-            args.train_iters = train_size // batch_size + 1
-            args.eval_iters = valid_size // batch_size + 1
-            args.test_iters = test_size // batch_size + 1
-        else:
+        # if args.run_dialog:
+        #     # Build the datasets.
+        #     train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
+        #     print_rank_0(' > datasets target sizes:')
+        #     train_size = len(train_ds)
+        #     valid_size = len(valid_ds)
+        #     test_size = len(test_ds)
+        #     print_rank_0('    train:      {}'.format(train_size))
+        #     print_rank_0('    validation: {}'.format(valid_size))
+        #     print_rank_0('    test:       {}'.format(test_size))
+        #     batch_size = args.global_batch_size
+        #     args.train_iters = train_size // batch_size + 1
+        #     args.eval_iters = valid_size // batch_size + 1
+        #     args.test_iters = test_size // batch_size + 1
+        # else:
         # Number of train/valid/test samples.
         if args.train_samples:
             train_samples = args.train_samples
...
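
Note: the branch commented out above sized the train/valid/test iteration counts from the dataset lengths rather than from the command-line --train-iters. A sketch of that arithmetic as written in the old code; note that size // batch + 1 always rounds up and produces one extra step when the size is an exact multiple of the batch size.

def iters_for(dataset_size, global_batch_size):
    """Iteration count as computed by the old run_dialog branch.

    Floor division plus one: the +1 covers any remainder samples, at the
    cost of one surplus iteration on exact multiples (the dataloader is
    presumably expected to cycle or pad the final batch).
    """
    return dataset_size // global_batch_size + 1

# Examples:
assert iters_for(1050, 100) == 11  # covers the 50-sample remainder
assert iters_for(1000, 100) == 11  # one extra step on an exact multiple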