OpenDAS / Megatron-LM / Commits / a31833ce

Commit a31833ce
authored Dec 11, 2020 by mshoeybi, committed by Deepak Narayanan on Dec 19, 2020
parent 51315905

Fix some bugs, add exit-duration capability

Showing 7 changed files with 158 additions and 42 deletions (+158, -42).
megatron/__init__.py            +1   -0
megatron/arguments.py           +2   -0
megatron/data/dataset_utils.py  +17  -5
megatron/data/gpt2_dataset.py   +15  -6
megatron/global_vars.py         +7   -2
megatron/microbatches.py        +22  -16
megatron/training.py            +94  -13
megatron/__init__.py
@@ -26,6 +26,7 @@ from .package_info import (
 )

 from .global_vars import get_args
+from .global_vars import get_current_global_batch_size
 from .global_vars import get_num_microbatches
 from .global_vars import update_num_microbatches
 from .global_vars import get_tokenizer
megatron/arguments.py
@@ -326,6 +326,8 @@ def _add_training_args(parser):
     group.add_argument('--exit-interval', type=int, default=None,
                        help='Exit the program after the iteration is divisible '
                        'by this value.')
+    group.add_argument('--exit-duration-in-mins', type=int, default=None,
+                       help='Exit the program after this many minutes.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
     group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
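For reference, a minimal, self-contained argparse sketch of how the new option behaves. The option strings and help text are taken from the hunk above; the bare parser is illustrative and not Megatron's actual argument setup.

import argparse

# Illustrative parser; in Megatron the options above live in _add_training_args().
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--exit-interval', type=int, default=None,
                   help='Exit the program after the iteration is divisible '
                   'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
                   help='Exit the program after this many minutes.')

# e.g. ask the job to wind down after roughly four hours of wall-clock time.
args = parser.parse_args(['--exit-duration-in-mins', '240'])
print(args.exit_duration_in_mins)  # 240
print(args.exit_interval)          # None (flag not given)

The parsed value is consumed in train() (megatron/training.py below), which checkpoints and exits once the elapsed wall-clock time exceeds it.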
megatron/data/dataset_utils.py
@@ -418,11 +418,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             max_seq_length, masked_lm_prob, short_seq_prob,
             seed, skip_warmup, dataset_type=dataset_type)
-    # Blend.
-    blending_train_dataset = BlendableDataset(train_datasets, weights)
-    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
-    blending_test_dataset = BlendableDataset(test_datasets, weights)
+        if train_ds:
+            train_datasets.append(train_ds)
+        if valid_ds:
+            valid_datasets.append(valid_ds)
+        if test_ds:
+            test_datasets.append(test_ds)
+
+    # Blend.
+    blending_train_dataset = None
+    if train_datasets:
+        blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = None
+    if valid_datasets:
+        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = None
+    if test_datasets:
+        blending_test_dataset = BlendableDataset(test_datasets, weights)

     return (blending_train_dataset, blending_valid_dataset,
             blending_test_dataset)
megatron/data/gpt2_dataset.py
@@ -55,14 +55,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             prefixes[i], data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             seq_length, seed, skip_warmup)
-        train_datasets.append(train_ds)
-        valid_datasets.append(valid_ds)
-        test_datasets.append(test_ds)
+        if train_ds:
+            train_datasets.append(train_ds)
+        if valid_ds:
+            valid_datasets.append(valid_ds)
+        if test_ds:
+            test_datasets.append(test_ds)

     # Blend.
-    blending_train_dataset = BlendableDataset(train_datasets, weights)
-    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
-    blending_test_dataset = BlendableDataset(test_datasets, weights)
+    blending_train_dataset = None
+    if train_datasets:
+        blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = None
+    if valid_datasets:
+        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = None
+    if test_datasets:
+        blending_test_dataset = BlendableDataset(test_datasets, weights)

     return (blending_train_dataset, blending_valid_dataset,
             blending_test_dataset)
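Both dataset builders (megatron/data/dataset_utils.py above and megatron/data/gpt2_dataset.py here) now apply the same guard: a prefix whose train/valid/test split comes back empty contributes nothing, and a split with no datasets at all blends to None instead of handing BlendableDataset an empty list. A small self-contained sketch of the pattern, with a stand-in blend() in place of the real BlendableDataset:

from typing import List, Tuple


def blend(datasets: List[str], weights: List[float]) -> str:
    # Stand-in for BlendableDataset: just record what would be combined.
    return 'Blend({}, {})'.format(datasets, weights)


def build(per_prefix: List[Tuple], weights: List[float]):
    # per_prefix holds (train_ds, valid_ds, test_ds) for each data prefix;
    # any of the three may be None when its split is empty.
    train_datasets, valid_datasets, test_datasets = [], [], []
    for train_ds, valid_ds, test_ds in per_prefix:
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)
    # Only blend non-empty collections; otherwise return None for that split.
    train = blend(train_datasets, weights) if train_datasets else None
    valid = blend(valid_datasets, weights) if valid_datasets else None
    test = blend(test_datasets, weights) if test_datasets else None
    return train, valid, test


# The second prefix has no test split; the test blend collapses to None.
print(build([('t0', 'v0', 'x0'), ('t1', 'v1', None)], weights=[0.7, 0.3]))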
megatron/global_vars.py
@@ -43,8 +43,13 @@ def get_num_microbatches():
     return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


-def update_num_microbatches(consumed_samples):
-    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)
+def get_current_global_batch_size():
+    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
+
+
+def update_num_microbatches(consumed_samples, consistency_check=True):
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
+                                               consistency_check)


 def get_tokenizer():
megatron/microbatches.py
@@ -56,12 +56,16 @@ class NumMicroBatchesCalculator(ABC):

     def __init__(self):
         self.num_micro_batches = None
+        self.current_global_batch_size = None

     def get(self):
         return self.num_micro_batches

+    def get_current_global_batch_size(self):
+        return self.current_global_batch_size
+
     @abstractmethod
-    def update(self, consumed_samples):
+    def update(self, consumed_samples, consistency_check):
         pass

@@ -78,8 +82,9 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
         self.num_micro_batches = global_batch_size // \
             micro_batch_times_data_parallel
         assert self.num_micro_batches >= 1
+        self.current_global_batch_size = global_batch_size

-    def update(self, consumed_samples):
+    def update(self, consumed_samples, consistency_check):
         pass

@@ -128,24 +133,25 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
         self.rampup_samples_per_increment = self.ramup_samples / num_increments

         # Initialize number of microbatches.
-        self.update(0)
+        self.update(0, False)

-    def update(self, consumed_samples):
+    def update(self, consumed_samples, consistency_check):
         if consumed_samples > self.ramup_samples:
-            current_global_batch_size = self.global_batch_size
+            self.current_global_batch_size = self.global_batch_size
         else:
             steps = int(consumed_samples / self.rampup_samples_per_increment)
-            current_global_batch_size = self.start_batch_size + \
-                steps * self.batch_size_increment
-            assert current_global_batch_size <= self.global_batch_size
-
-        assert current_global_batch_size % \
-            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
-            'batch size ({}) is not divisible by micro-batch-size ({}) times' \
-            'data parallel size ({})'.format(current_global_batch_size,
-                                             self.micro_batch_size,
-                                             self.data_parallel_size)
-        self.num_micro_batches = current_global_batch_size // \
+            self.current_global_batch_size = self.start_batch_size + \
+                steps * self.batch_size_increment
+            assert self.current_global_batch_size <= self.global_batch_size
+
+        if consistency_check:
+            assert self.current_global_batch_size % \
+                self.micro_batch_times_data_parallel_size == 0, 'current global ' \
+                'batch size ({}) is not divisible by micro-batch-size ({}) times' \
+                'data parallel size ({})'.format(self.current_global_batch_size,
+                                                 self.micro_batch_size,
+                                                 self.data_parallel_size)
+        self.num_micro_batches = self.current_global_batch_size // \
             self.micro_batch_times_data_parallel_size
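The consistency_check flag exists because update() is now also called from places that merely replay the ramp-up schedule (the constructor's own self.update(0, False) and update_train_iters() in megatron/training.py), and those callers pass False so the divisibility assertion is skipped. A self-contained sketch of the ramp-up arithmetic with made-up numbers; these hyperparameters are hypothetical, and the function mirrors but does not import the class above:

# Hypothetical ramp-up: start at global batch size 32, add 16 per increment,
# finish at 96, spreading the ramp-up over 3000 consumed samples.
start_batch_size = 32
batch_size_increment = 16
global_batch_size = 96
rampup_samples = 3000
micro_batch_times_data_parallel = 8  # micro_batch_size * data_parallel_size

num_increments = (global_batch_size - start_batch_size) // batch_size_increment
rampup_samples_per_increment = rampup_samples / num_increments  # 750.0


def current_global_batch_size(consumed_samples):
    # Mirrors the ramp-up branch of RampupBatchsizeNumMicroBatches.update().
    if consumed_samples > rampup_samples:
        return global_batch_size
    steps = int(consumed_samples / rampup_samples_per_increment)
    return start_batch_size + steps * batch_size_increment


for consumed in (0, 800, 1600, 2400, 3200):
    gbs = current_global_batch_size(consumed)
    print('consumed={:5d}  global_batch_size={:3d}  num_micro_batches={}'.format(
        consumed, gbs, gbs // micro_batch_times_data_parallel))

With these numbers every intermediate batch size happens to divide evenly; get_current_global_batch_size() is what update_train_iters() now adds to consumed_samples each iteration, instead of recomputing the product of get_num_microbatches(), micro_batch_size, and data_parallel_size.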
megatron/training.py
@@ -18,6 +18,10 @@
 from datetime import datetime
 import math
+import sys
 import time
+# The earliest we can measure the start time.
+_TRAIN_START_TIME = time.time()
+
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam

@@ -25,6 +29,7 @@ from apex.optimizers import FusedAdam as Adam
 from megatron import get_args
 from megatron import get_timers
 from megatron import get_tensorboard_writer
+from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
 from megatron import update_num_microbatches
 from megatron import mpu

@@ -44,6 +49,13 @@ from megatron.data.data_loaders import build_pretraining_data_loader
 from megatron.utils import report_memory


+def print_datetime(string):
+    """Note that this call will sync across all ranks."""
+    torch.distributed.barrier()
+    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+    print_rank_0('[' + string + '] datetime: {} '.format(time_str))
+
+
 def pretrain(train_valid_test_dataset_provider, model_provider,
              forward_step_func, extra_args_provider=None, args_defaults={}):
     """Main training program.

@@ -74,6 +86,18 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)

+    # Adjust the startup time so it reflects the largest value.
+    # This will be closer to what scheduler will see (outside of
+    # image ... launches.
+    global _TRAIN_START_TIME
+    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
+    torch.distributed.all_reduce(start_time_tensor,
+                                 op=torch.distributed.ReduceOp.MIN)
+    _TRAIN_START_TIME = start_time_tensor.item()
+    print_rank_0('time took to initialize megatron (seconds): {:.3f}'.format(
+        time.time() - _TRAIN_START_TIME))
+    print_datetime('after megatron is initialized')
+
     args = get_args()
     timers = get_timers()

@@ -81,6 +105,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
     timers('model and optimizer').stop()
+    print_datetime('after model, optimizer, and learning rate '
+                   'scheduler are built')

     # Data stuff.
     timers('train/valid/test data iterators').start()

@@ -88,6 +114,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
         = build_train_valid_test_data_iterators(
             train_valid_test_dataset_provider)
     timers('train/valid/test data iterators').stop()
+    print_datetime('after dataloaders are build')

     # Print setup timing.
     print_rank_0('done with setups ...')

@@ -99,6 +126,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
                           train_data_iterator, valid_data_iterator)
+    print_datetime('after training is done')

     if args.do_valid:
         prefix = 'the end of training for val data'

@@ -132,13 +160,11 @@ def update_train_iters(args):
     consumed_samples = 0
     # Rampup phase.
     while consumed_samples <= int(args.rampup_batch_size[2]):
-        update_num_microbatches(consumed_samples)
-        consumed_samples += get_num_microbatches() * \
-                            args.micro_batch_size * \
-                            args.data_parallel_size
+        update_num_microbatches(consumed_samples, consistency_check=False)
+        consumed_samples += get_current_global_batch_size()
         iterations += 1
     # Reset
-    update_num_microbatches(0)
+    update_num_microbatches(0, consistency_check=False)

     # Constant phase
     # Note that we throw away any partial last batch.
     iterations += (args.train_samples - consumed_samples) // \

@@ -267,7 +293,15 @@ def setup_model_and_optimizer(model_provider_func):
     lr_scheduler = get_learning_rate_scheduler(optimizer)

     if args.load is not None:
+        timers = get_timers()
+        # Extra barrier is added to make sure all ranks report the
+        # max time.
+        torch.distributed.barrier()
+        timers('load checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        torch.distributed.barrier()
+        timers('load checkpoint').stop()
+        timers.log(['load checkpoint'])
     else:
         args.iteration = 0
@@ -685,11 +719,22 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     # Tensorboard values.
     if writer and torch.distributed.get_rank() == 0:
-        writer.add_scalar('learning_rate', learning_rate, iteration)
+        writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
+        writer.add_scalar('learning_rate-samples', learning_rate,
+                          args.consumed_train_samples)
+        batch_size = args.micro_batch_size * args.data_parallel_size * \
+            get_num_microbatches()
+        writer.add_scalar('batch_size-iterations', batch_size, iteration)
+        writer.add_scalar('batch_size-samples', batch_size,
+                          args.consumed_train_samples)
         for key in loss_dict:
-            writer.add_scalar(key, loss_dict[key], iteration)
+            writer.add_scalar(key + '-iterations', loss_dict[key], iteration)
+            writer.add_scalar(key + '-samples', loss_dict[key],
+                              args.consumed_train_samples)
         if args.fp16:
-            writer.add_scalar('loss_scale', loss_scale, iteration)
+            writer.add_scalar('loss_scale-iterations', loss_scale, iteration)
+            writer.add_scalar('loss_scale-samples', loss_scale,
+                              args.consumed_train_samples)
     normalizer = iteration % args.log_interval
     if normalizer == 0:
         normalizer = args.log_interval
@@ -703,6 +748,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             elapsed_time / args.log_interval, iteration)
     log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                    args.train_iters)
+    log_string += ' consumed samples {:12d} |'.format(
+        args.consumed_train_samples)
     log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
         elapsed_time * 1000.0 / args.log_interval)
     log_string += ' learning rate: {:.3E} |'.format(learning_rate)

@@ -732,6 +779,18 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     return report_memory_flag


+def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
+    timers = get_timers()
+    # Extra barrier is added to make sure
+    # all ranks report the max time.
+    torch.distributed.barrier()
+    timers('save checkpoint').start()
+    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    torch.distributed.barrier()
+    timers('save checkpoint').stop()
+    timers.log(['save checkpoint'])
+
+
 def train(forward_step_func, model, optimizer, lr_scheduler,
           train_data_iterator, valid_data_iterator):
     """Train the model function."""

@@ -748,6 +807,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     iteration = args.iteration

     timers('interval time').start()
+    print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
         update_num_microbatches(args.consumed_train_samples)

@@ -777,9 +837,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                           lr_scheduler)

         # Checkpointing
+        saved_checkpoint = False
         if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint_and_time(iteration, model, optimizer,
+                                     lr_scheduler)
+            saved_checkpoint = True

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \

@@ -789,14 +853,31 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                        valid_data_iterator, model,
                                        iteration, False)

+        # Exiting based on duration
+        if args.exit_duration_in_mins:
+            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
+            done_cuda = torch.cuda.IntTensor(
+                [train_time > args.exit_duration_in_mins])
+            torch.distributed.all_reduce(
+                done_cuda, op=torch.distributed.ReduceOp.MAX)
+            done = done_cuda.item()
+            if done:
+                if not saved_checkpoint:
+                    save_checkpoint_and_time(iteration, model, optimizer,
+                                             lr_scheduler)
+                print_datetime('exiting program after {} minutes'.format(
+                    train_time))
+                sys.exit()
+
+        # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
+            if not saved_checkpoint:
+                save_checkpoint_and_time(iteration, model, optimizer,
+                                         lr_scheduler)
             torch.distributed.barrier()
-            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-            rank = torch.distributed.get_rank()
-            print_rank_0('rank: {} | time: {} | exiting the program at '
-                         'iteration {}'.format(rank, time_str, iteration))
+            print_datetime('exiting program at iteration {}'.format(
+                iteration))
             sys.exit()

     return iteration
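How the new exit-duration path decides to stop, in brief: each rank compares its own elapsed wall-clock time (measured from _TRAIN_START_TIME, which is MIN-reduced at startup so the elapsed time reflects the earliest-launched rank) against --exit-duration-in-mins, and the per-rank flags are MAX-reduced so that a single rank past the deadline makes every rank checkpoint and exit together. A minimal sketch of that decision with the distributed pieces stubbed out; no GPU or process group is needed, and all_reduce_max stands in for torch.distributed.all_reduce with ReduceOp.MAX:

import time


def all_reduce_max(flags):
    # Stand-in for torch.distributed.all_reduce(done_cuda, op=ReduceOp.MAX).
    return max(flags)


def should_exit(train_start_times, exit_duration_in_mins, now=None):
    # One rank past the deadline is enough for every rank to exit.
    now = time.time() if now is None else now
    flags = [int((now - start) / 60.0 > exit_duration_in_mins)
             for start in train_start_times]
    return bool(all_reduce_max(flags))


# Hypothetical 4-rank job with a 240-minute budget: one rank has been up
# for 245 minutes, so all ranks agree it is time to checkpoint and exit.
now = time.time()
starts = [now - minutes * 60.0 for minutes in (238, 245, 240, 239)]
print(should_exit(starts, exit_duration_in_mins=240, now=now))  # True

In train() itself the flag is a one-element torch.cuda.IntTensor, the reduction is a real all_reduce, and save_checkpoint_and_time() is called first unless a checkpoint was already written in the same iteration.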