Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
294e81c1
Commit
294e81c1
authored
Jul 05, 2021
by
zihanl
Browse files
update training.py
parent
e57a8f74
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
95 additions
and
96 deletions
+95
-96
megatron/training.py
megatron/training.py
+95
-96
No files found.
megatron/training.py
View file @
294e81c1
...
@@ -138,61 +138,60 @@ def pretrain(train_valid_test_dataset_provider,
...
@@ -138,61 +138,60 @@ def pretrain(train_valid_test_dataset_provider,
print_rank_0
(
'training ...'
)
print_rank_0
(
'training ...'
)
iteration
=
0
iteration
=
0
if
not
args
.
run_dialog
:
# if not args.run_dialog:
# original pre-training for GPT
if
args
.
do_train
and
args
.
train_iters
>
0
:
if
args
.
do_train
and
args
.
train_iters
>
0
:
iteration
=
train
(
forward_step_func
,
iteration
=
train
(
forward_step_func
,
model
,
optimizer
,
lr_scheduler
,
model
,
optimizer
,
lr_scheduler
,
train_data_iterator
,
valid_data_iterator
)
train_data_iterator
,
valid_data_iterator
)
print_datetime
(
'after training is done'
)
print_datetime
(
'after training is done'
)
if
args
.
do_valid
:
if
args
.
do_valid
:
prefix
=
'the end of training for val data'
prefix
=
'the end of training for val data'
evaluate_and_print_results
(
prefix
,
forward_step_func
,
evaluate_and_print_results
(
prefix
,
forward_step_func
,
valid_data_iterator
,
model
,
valid_data_iterator
,
model
,
iteration
,
False
)
iteration
,
False
)
if
args
.
save
and
iteration
!=
0
:
if
args
.
save
and
iteration
!=
0
:
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
)
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
)
if
args
.
do_test
:
if
args
.
do_test
:
# Run on test data.
# Run on test data.
prefix
=
'the end of training for test data'
prefix
=
'the end of training for test data'
evaluate_and_print_results
(
prefix
,
forward_step_func
,
evaluate_and_print_results
(
prefix
,
forward_step_func
,
test_data_iterator
,
model
,
test_data_iterator
,
model
,
0
,
True
)
0
,
True
)
else
:
#
else:
# training for dialog/control model
#
# training for dialog/control model
timers
(
'interval-time'
).
start
()
# start timers('interval-time') here to avoid it from starting multiple times
#
timers('interval-time').start() # start timers('interval-time') here to avoid it from starting multiple times
for
e
in
range
(
args
.
num_epoch
):
#
for e in range(args.num_epoch):
print_rank_0
(
'> training on epoch %d'
%
(
e
+
1
))
#
print_rank_0('> training on epoch %d' % (e+1))
if
args
.
do_train
and
args
.
train_iters
>
0
:
#
if args.do_train and args.train_iters > 0:
iteration
+=
train
(
forward_step_func
,
#
iteration += train(forward_step_func,
model
,
optimizer
,
lr_scheduler
,
#
model, optimizer, lr_scheduler,
train_data_iterator
,
valid_data_iterator
)
#
train_data_iterator, valid_data_iterator)
print_datetime
(
'after training is done'
)
#
print_datetime('after training is done')
if
args
.
do_valid
:
#
if args.do_valid:
prefix
=
'the end of training for val data'
#
prefix = 'the end of training for val data'
evaluate_and_print_results
(
prefix
,
forward_step_func
,
#
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator
,
model
,
#
valid_data_iterator, model,
iteration
,
False
)
#
iteration, False)
# if args.train_module == "dialog":
#
# if args.train_module == "dialog":
# if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
#
# if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
# save_checkpoint(iteration, model, optimizer, lr_scheduler)
#
# save_checkpoint(iteration, model, optimizer, lr_scheduler)
if
args
.
train_module
==
"control"
:
#
if args.train_module == "control":
if
(
e
+
1
)
>=
5
and
(
e
+
1
)
<=
9
and
args
.
save
and
iteration
!=
0
:
#
if (e+1) >= 5 and (e+1) <= 9 and args.save and iteration != 0:
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
)
#
save_checkpoint(iteration, model, optimizer, lr_scheduler)
if
args
.
do_test
:
#
if args.do_test:
# Run on test data.
#
# Run on test data.
prefix
=
'the end of training for test data'
#
prefix = 'the end of training for test data'
evaluate_and_print_results
(
prefix
,
forward_step_func
,
#
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator
,
model
,
#
test_data_iterator, model,
0
,
True
)
#
0, True)
def
update_train_iters
(
args
):
def
update_train_iters
(
args
):
...
@@ -645,8 +644,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
...
@@ -645,8 +644,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Iterations.
# Iterations.
iteration
=
args
.
iteration
iteration
=
args
.
iteration
if
not
args
.
run_dialog
:
#
if not args.run_dialog:
timers
(
'interval-time'
).
start
()
timers
(
'interval-time'
).
start
()
print_datetime
(
'before the start of training step'
)
print_datetime
(
'before the start of training step'
)
report_memory_flag
=
True
report_memory_flag
=
True
...
@@ -829,51 +828,51 @@ def build_train_valid_test_data_iterators(
...
@@ -829,51 +828,51 @@ def build_train_valid_test_data_iterators(
args
.
consumed_valid_samples
=
(
args
.
iteration
//
args
.
eval_interval
)
*
\
args
.
consumed_valid_samples
=
(
args
.
iteration
//
args
.
eval_interval
)
*
\
args
.
eval_iters
*
args
.
global_batch_size
args
.
eval_iters
*
args
.
global_batch_size
if
args
.
run_dialog
:
#
if args.run_dialog:
args
.
consumed_train_samples
=
0
#
args.consumed_train_samples = 0
args
.
consumed_valid_samples
=
0
#
args.consumed_valid_samples = 0
args
.
iteration
=
0
#
args.iteration = 0
# Data loader only on rank 0 of each model parallel group.
# Data loader only on rank 0 of each model parallel group.
if
mpu
.
get_tensor_model_parallel_rank
()
==
0
:
if
mpu
.
get_tensor_model_parallel_rank
()
==
0
:
if
args
.
run_dialog
:
# if args.run_dialog:
# Build the datasets.
# # Build the datasets.
train_ds
,
valid_ds
,
test_ds
=
build_train_valid_test_datasets_provider
()
# train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
print_rank_0
(
' > datasets target sizes:'
)
# print_rank_0(' > datasets target sizes:')
train_size
=
len
(
train_ds
)
# train_size = len(train_ds)
valid_size
=
len
(
valid_ds
)
# valid_size = len(valid_ds)
test_size
=
len
(
test_ds
)
# test_size = len(test_ds)
print_rank_0
(
' train: {}'
.
format
(
train_size
))
# print_rank_0(' train: {}'.format(train_size))
print_rank_0
(
' validation: {}'
.
format
(
valid_size
))
# print_rank_0(' validation: {}'.format(valid_size))
print_rank_0
(
' test: {}'
.
format
(
test_size
))
# print_rank_0(' test: {}'.format(test_size))
batch_size
=
args
.
global_batch_size
# batch_size = args.global_batch_size
args
.
train_iters
=
train_size
//
batch_size
+
1
# args.train_iters = train_size // batch_size + 1
args
.
eval_iters
=
valid_size
//
batch_size
+
1
# args.eval_iters = valid_size // batch_size + 1
args
.
test_iters
=
test_size
//
batch_size
+
1
# args.test_iters = test_size // batch_size + 1
# else:
# Number of train/valid/test samples.
if
args
.
train_samples
:
train_samples
=
args
.
train_samples
else
:
else
:
# Number of train/valid/test samples.
train_samples
=
args
.
train_iters
*
args
.
global_batch_size
if
args
.
train_samples
:
eval_iters
=
(
args
.
train_iters
//
args
.
eval_interval
+
1
)
*
\
train_samples
=
args
.
train_samples
args
.
eval_iters
else
:
test_iters
=
args
.
eval_iters
train_samples
=
args
.
train_iters
*
args
.
global_batch_size
train_val_test_num_samples
=
[
train_samples
,
eval_iters
=
(
args
.
train_iters
//
args
.
eval_interval
+
1
)
*
\
eval_iters
*
args
.
global_batch_size
,
args
.
eval_iters
test_iters
*
args
.
global_batch_size
]
test_iters
=
args
.
eval_iters
print_rank_0
(
' > datasets target sizes (minimum size):'
)
train_val_test_num_samples
=
[
train_samples
,
print_rank_0
(
' train: {}'
.
format
(
train_val_test_num_samples
[
0
]))
eval_iters
*
args
.
global_batch_size
,
print_rank_0
(
' validation: {}'
.
format
(
train_val_test_num_samples
[
1
]))
test_iters
*
args
.
global_batch_size
]
print_rank_0
(
' test: {}'
.
format
(
train_val_test_num_samples
[
2
]))
print_rank_0
(
' > datasets target sizes (minimum size):'
)
print_rank_0
(
' train: {}'
.
format
(
train_val_test_num_samples
[
0
]))
# Build the datasets.
print_rank_0
(
' validation: {}'
.
format
(
train_val_test_num_samples
[
1
]))
train_ds
,
valid_ds
,
test_ds
=
build_train_valid_test_datasets_provider
(
print_rank_0
(
' test: {}'
.
format
(
train_val_test_num_samples
[
2
]))
train_val_test_num_samples
)
# Build the datasets.
train_ds
,
valid_ds
,
test_ds
=
build_train_valid_test_datasets_provider
(
train_val_test_num_samples
)
# Build dataloders.
# Build dataloders.
train_dataloader
=
build_pretraining_data_loader
(
train_dataloader
=
build_pretraining_data_loader
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment