Commit 294e81c1 authored by zihanl

update training.py

parent e57a8f74
@@ -138,61 +138,60 @@ def pretrain(train_valid_test_dataset_provider,
     print_rank_0('training ...')
     iteration = 0
-    if not args.run_dialog:
-        # original pre-training for GPT
-        if args.do_train and args.train_iters > 0:
-            iteration = train(forward_step_func,
-                              model, optimizer, lr_scheduler,
-                              train_data_iterator, valid_data_iterator)
-        print_datetime('after training is done')
-        if args.do_valid:
-            prefix = 'the end of training for val data'
-            evaluate_and_print_results(prefix, forward_step_func,
-                                       valid_data_iterator, model,
-                                       iteration, False)
-        if args.save and iteration != 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
-        if args.do_test:
-            # Run on test data.
-            prefix = 'the end of training for test data'
-            evaluate_and_print_results(prefix, forward_step_func,
-                                       test_data_iterator, model,
-                                       0, True)
+    # if not args.run_dialog:
+    if args.do_train and args.train_iters > 0:
+        iteration = train(forward_step_func,
+                          model, optimizer, lr_scheduler,
+                          train_data_iterator, valid_data_iterator)
+    print_datetime('after training is done')
+    if args.do_valid:
+        prefix = 'the end of training for val data'
+        evaluate_and_print_results(prefix, forward_step_func,
+                                   valid_data_iterator, model,
+                                   iteration, False)
+    if args.save and iteration != 0:
+        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    if args.do_test:
+        # Run on test data.
+        prefix = 'the end of training for test data'
+        evaluate_and_print_results(prefix, forward_step_func,
+                                   test_data_iterator, model,
+                                   0, True)
-    else:
-        # training for dialog/control model
-        timers('interval-time').start()  # start timers('interval-time') here to avoid starting it multiple times
-        for e in range(args.num_epoch):
-            print_rank_0('> training on epoch %d' % (e+1))
-            if args.do_train and args.train_iters > 0:
-                iteration += train(forward_step_func,
-                                   model, optimizer, lr_scheduler,
-                                   train_data_iterator, valid_data_iterator)
-            print_datetime('after training is done')
-            if args.do_valid:
-                prefix = 'the end of training for val data'
-                evaluate_and_print_results(prefix, forward_step_func,
-                                           valid_data_iterator, model,
-                                           iteration, False)
-            # if args.train_module == "dialog":
-            #     if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
-            #         save_checkpoint(iteration, model, optimizer, lr_scheduler)
-            if args.train_module == "control":
-                if (e+1) >= 5 and (e+1) <= 9 and args.save and iteration != 0:
-                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
-        if args.do_test:
-            # Run on test data.
-            prefix = 'the end of training for test data'
-            evaluate_and_print_results(prefix, forward_step_func,
-                                       test_data_iterator, model,
-                                       0, True)
+    # else:
+    #     # training for dialog/control model
+    #     timers('interval-time').start()  # start timers('interval-time') here to avoid starting it multiple times
+    #     for e in range(args.num_epoch):
+    #         print_rank_0('> training on epoch %d' % (e+1))
+    #         if args.do_train and args.train_iters > 0:
+    #             iteration += train(forward_step_func,
+    #                                model, optimizer, lr_scheduler,
+    #                                train_data_iterator, valid_data_iterator)
+    #         print_datetime('after training is done')
+    #         if args.do_valid:
+    #             prefix = 'the end of training for val data'
+    #             evaluate_and_print_results(prefix, forward_step_func,
+    #                                        valid_data_iterator, model,
+    #                                        iteration, False)
+    #         # if args.train_module == "dialog":
+    #         #     if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
+    #         #         save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    #         if args.train_module == "control":
+    #             if (e+1) >= 5 and (e+1) <= 9 and args.save and iteration != 0:
+    #                 save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    #     if args.do_test:
+    #         # Run on test data.
+    #         prefix = 'the end of training for test data'
+    #         evaluate_and_print_results(prefix, forward_step_func,
+    #                                    test_data_iterator, model,
+    #                                    0, True)
 def update_train_iters(args):
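
A note on the epoch-gated checkpointing in the dialog/control branch above: checkpoints are written only while the 1-based epoch number falls inside a hard-coded window (epochs 5 through 9 for the "control" module), not on the usual args.save_interval schedule. A minimal runnable sketch of that pattern, where save_checkpoint is a hypothetical stub and the epoch/iteration counts are made up (Megatron's real save_checkpoint also takes model, optimizer, and lr_scheduler):

def save_checkpoint(iteration):
    # Hypothetical stub standing in for Megatron's checkpoint writer.
    print('saved checkpoint at iteration %d' % iteration)

num_epoch = 10           # assumed value of args.num_epoch
iters_per_epoch = 100    # assumed iterations that train() runs per epoch
iteration = 0
for e in range(num_epoch):
    iteration += iters_per_epoch
    # Same gate as the "control" branch: persist only epochs 5 through 9.
    if 5 <= (e + 1) <= 9 and iteration != 0:
        save_checkpoint(iteration)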
@@ -645,8 +644,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
-    if not args.run_dialog:
-        timers('interval-time').start()
+    # if not args.run_dialog:
+    timers('interval-time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
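
Why 'interval-time' must be started in exactly one place: Megatron-style timers assert when start() is called on a timer that is already running, which is what the comment in pretrain alludes to. A simplified sketch of that behavior (an assumed stand-in, not the actual megatron.timers implementation):

import time

class SimpleTimer:
    # Simplified stand-in for Megatron's _Timer: a double start is an error.
    def __init__(self):
        self.started = False
        self.start_time = 0.0
    def start(self):
        assert not self.started, 'timer has already been started'
        self.start_time = time.time()
        self.started = True
    def stop(self):
        assert self.started, 'timer is not started'
        self.started = False
        return time.time() - self.start_time

timer = SimpleTimer()
timer.start()            # ok: first start
# timer.start()          # would raise: started twice without stop()
elapsed = timer.stop()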
@@ -829,51 +828,51 @@ def build_train_valid_test_data_iterators(
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                                       args.eval_iters * args.global_batch_size
-    if args.run_dialog:
-        args.consumed_train_samples = 0
-        args.consumed_valid_samples = 0
-        args.iteration = 0
+    # if args.run_dialog:
+    #     args.consumed_train_samples = 0
+    #     args.consumed_valid_samples = 0
+    #     args.iteration = 0
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_tensor_model_parallel_rank() == 0:
-        if args.run_dialog:
-            # Build the datasets.
-            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
-            print_rank_0(' > datasets target sizes:')
-            train_size = len(train_ds)
-            valid_size = len(valid_ds)
-            test_size = len(test_ds)
-            print_rank_0('    train:      {}'.format(train_size))
-            print_rank_0('    validation: {}'.format(valid_size))
-            print_rank_0('    test:       {}'.format(test_size))
-            batch_size = args.global_batch_size
-            args.train_iters = train_size // batch_size + 1
-            args.eval_iters = valid_size // batch_size + 1
-            args.test_iters = test_size // batch_size + 1
-        else:
-            # Number of train/valid/test samples.
-            if args.train_samples:
-                train_samples = args.train_samples
-            else:
-                train_samples = args.train_iters * args.global_batch_size
-            eval_iters = (args.train_iters // args.eval_interval + 1) * \
-                         args.eval_iters
-            test_iters = args.eval_iters
-            train_val_test_num_samples = [train_samples,
-                                          eval_iters * args.global_batch_size,
-                                          test_iters * args.global_batch_size]
-            print_rank_0(' > datasets target sizes (minimum size):')
-            print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
-            print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
-            print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
-            # Build the datasets.
-            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
-                train_val_test_num_samples)
+        # if args.run_dialog:
+        #     # Build the datasets.
+        #     train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
+        #     print_rank_0(' > datasets target sizes:')
+        #     train_size = len(train_ds)
+        #     valid_size = len(valid_ds)
+        #     test_size = len(test_ds)
+        #     print_rank_0('    train:      {}'.format(train_size))
+        #     print_rank_0('    validation: {}'.format(valid_size))
+        #     print_rank_0('    test:       {}'.format(test_size))
+        #     batch_size = args.global_batch_size
+        #     args.train_iters = train_size // batch_size + 1
+        #     args.eval_iters = valid_size // batch_size + 1
+        #     args.test_iters = test_size // batch_size + 1
+        # else:
+        # Number of train/valid/test samples.
+        if args.train_samples:
+            train_samples = args.train_samples
+        else:
+            train_samples = args.train_iters * args.global_batch_size
+        eval_iters = (args.train_iters // args.eval_interval + 1) * \
+                     args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_samples,
+                                      eval_iters * args.global_batch_size,
+                                      test_iters * args.global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
+        # Build the datasets.
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
+            train_val_test_num_samples)
     # Build dataloaders.
     train_dataloader = build_pretraining_data_loader(
......
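
One arithmetic detail in the commented-out dialog sizing code above: train_size // batch_size + 1 yields one extra (empty) iteration whenever batch_size divides train_size exactly, whereas ceiling division gives the tight count. A small illustration with made-up sizes; the helper names here are hypothetical, not part of the file:

def iters_floor_plus_one(num_samples, batch_size):
    # Pattern used above: floor division plus one.
    return num_samples // batch_size + 1

def iters_ceil(num_samples, batch_size):
    # Exact ceiling division: smallest count covering every sample.
    return -(-num_samples // batch_size)

print(iters_floor_plus_one(1000, 100))  # 11 -> one spurious iteration
print(iters_ceil(1000, 100))            # 10
print(iters_floor_plus_one(1050, 100))  # 11
print(iters_ceil(1050, 100))            # 11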