Commit 294e81c1 authored by zihanl

update training.py

parent e57a8f74
@@ -138,61 +138,60 @@ def pretrain(train_valid_test_dataset_provider,
     print_rank_0('training ...')
     iteration = 0
-    if not args.run_dialog:
-        # original pre-training for GPT
-        if args.do_train and args.train_iters > 0:
-            iteration = train(forward_step_func,
-                              model, optimizer, lr_scheduler,
-                              train_data_iterator, valid_data_iterator)
-        print_datetime('after training is done')
-
-        if args.do_valid:
-            prefix = 'the end of training for val data'
-            evaluate_and_print_results(prefix, forward_step_func,
-                                       valid_data_iterator, model,
-                                       iteration, False)
-
-        if args.save and iteration != 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
-
-        if args.do_test:
-            # Run on test data.
-            prefix = 'the end of training for test data'
-            evaluate_and_print_results(prefix, forward_step_func,
-                                       test_data_iterator, model,
-                                       0, True)
-    else:
-        # training for dialog/control model
-        timers('interval-time').start() # start timers('interval-time') here to avoid it from starting multiple times
-        for e in range(args.num_epoch):
-            print_rank_0('> training on epoch %d' % (e+1))
-            if args.do_train and args.train_iters > 0:
-                iteration += train(forward_step_func,
-                                   model, optimizer, lr_scheduler,
-                                   train_data_iterator, valid_data_iterator)
-                print_datetime('after training is done')
-            if args.do_valid:
-                prefix = 'the end of training for val data'
-                evaluate_and_print_results(prefix, forward_step_func,
-                                           valid_data_iterator, model,
-                                           iteration, False)
-            # if args.train_module == "dialog":
-            #     if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
-            #         save_checkpoint(iteration, model, optimizer, lr_scheduler)
-            if args.train_module == "control":
-                if (e+1) >= 5 and (e+1) <= 9 and args.save and iteration != 0:
-                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
-        if args.do_test:
-            # Run on test data.
-            prefix = 'the end of training for test data'
-            evaluate_and_print_results(prefix, forward_step_func,
-                                       test_data_iterator, model,
-                                       0, True)
+    # if not args.run_dialog:
+    if args.do_train and args.train_iters > 0:
+        iteration = train(forward_step_func,
+                          model, optimizer, lr_scheduler,
+                          train_data_iterator, valid_data_iterator)
+    print_datetime('after training is done')
+
+    if args.do_valid:
+        prefix = 'the end of training for val data'
+        evaluate_and_print_results(prefix, forward_step_func,
+                                   valid_data_iterator, model,
+                                   iteration, False)
+
+    if args.save and iteration != 0:
+        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+
+    if args.do_test:
+        # Run on test data.
+        prefix = 'the end of training for test data'
+        evaluate_and_print_results(prefix, forward_step_func,
+                                   test_data_iterator, model,
+                                   0, True)
+    # else:
+    #     # training for dialog/control model
+    #     timers('interval-time').start() # start timers('interval-time') here to avoid it from starting multiple times
+    #     for e in range(args.num_epoch):
+    #         print_rank_0('> training on epoch %d' % (e+1))
+    #         if args.do_train and args.train_iters > 0:
+    #             iteration += train(forward_step_func,
+    #                                model, optimizer, lr_scheduler,
+    #                                train_data_iterator, valid_data_iterator)
+    #             print_datetime('after training is done')
+    #         if args.do_valid:
+    #             prefix = 'the end of training for val data'
+    #             evaluate_and_print_results(prefix, forward_step_func,
+    #                                        valid_data_iterator, model,
+    #                                        iteration, False)
+    #         # if args.train_module == "dialog":
+    #         #     if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
+    #         #         save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    #         if args.train_module == "control":
+    #             if (e+1) >= 5 and (e+1) <= 9 and args.save and iteration != 0:
+    #                 save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    #     if args.do_test:
+    #         # Run on test data.
+    #         prefix = 'the end of training for test data'
+    #         evaluate_and_print_results(prefix, forward_step_func,
+    #                                    test_data_iterator, model,
+    #                                    0, True)
 
 def update_train_iters(args):
@@ -645,8 +644,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
 
-    if not args.run_dialog:
-        timers('interval-time').start()
+    # if not args.run_dialog:
+    timers('interval-time').start()
 
     print_datetime('before the start of training step')
     report_memory_flag = True
@@ -829,51 +828,51 @@ def build_train_valid_test_data_iterators(
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
             args.eval_iters * args.global_batch_size
 
-    if args.run_dialog:
-        args.consumed_train_samples = 0
-        args.consumed_valid_samples = 0
-        args.iteration = 0
+    # if args.run_dialog:
+    #     args.consumed_train_samples = 0
+    #     args.consumed_valid_samples = 0
+    #     args.iteration = 0
 
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_tensor_model_parallel_rank() == 0:
-        if args.run_dialog:
-            # Build the datasets.
-            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
-            print_rank_0(' > datasets target sizes:')
-            train_size = len(train_ds)
-            valid_size = len(valid_ds)
-            test_size = len(test_ds)
-            print_rank_0('    train:      {}'.format(train_size))
-            print_rank_0('    validation: {}'.format(valid_size))
-            print_rank_0('    test:       {}'.format(test_size))
-            batch_size = args.global_batch_size
-            args.train_iters = train_size // batch_size + 1
-            args.eval_iters = valid_size // batch_size + 1
-            args.test_iters = test_size // batch_size + 1
-        else:
-            # Number of train/valid/test samples.
-            if args.train_samples:
-                train_samples = args.train_samples
-            else:
-                train_samples = args.train_iters * args.global_batch_size
-            eval_iters = (args.train_iters // args.eval_interval + 1) * \
-                         args.eval_iters
-            test_iters = args.eval_iters
-            train_val_test_num_samples = [train_samples,
-                                          eval_iters * args.global_batch_size,
-                                          test_iters * args.global_batch_size]
-            print_rank_0(' > datasets target sizes (minimum size):')
-            print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
-            print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
-            print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
-
-            # Build the datasets.
-            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
-                train_val_test_num_samples)
+        # if args.run_dialog:
+        #     # Build the datasets.
+        #     train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
+        #     print_rank_0(' > datasets target sizes:')
+        #     train_size = len(train_ds)
+        #     valid_size = len(valid_ds)
+        #     test_size = len(test_ds)
+        #     print_rank_0('    train:      {}'.format(train_size))
+        #     print_rank_0('    validation: {}'.format(valid_size))
+        #     print_rank_0('    test:       {}'.format(test_size))
+        #     batch_size = args.global_batch_size
+        #     args.train_iters = train_size // batch_size + 1
+        #     args.eval_iters = valid_size // batch_size + 1
+        #     args.test_iters = test_size // batch_size + 1
+        # else:
+        # Number of train/valid/test samples.
+        if args.train_samples:
+            train_samples = args.train_samples
+        else:
+            train_samples = args.train_iters * args.global_batch_size
+        eval_iters = (args.train_iters // args.eval_interval + 1) * \
+                     args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_samples,
+                                      eval_iters * args.global_batch_size,
+                                      test_iters * args.global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
+
+        # Build the datasets.
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
+            train_val_test_num_samples)
 
         # Build dataloders.
         train_dataloader = build_pretraining_data_loader(
...