OpenDAS / Megatron-LM

Commit 6ea23928
Authored Dec 06, 2020 by mohammad; committed by Deepak Narayanan on Dec 19, 2020

Add micro-batch size calculator

Parent: 9019bbf4
Showing 6 changed files with 149 additions and 28 deletions (+149, -28)
megatron/__init__.py            +2   -0
megatron/arguments.py           +47  -6
megatron/checkpointing.py       +2   -1
megatron/data/data_loaders.py   +1   -2
megatron/global_vars.py         +76  -0
megatron/training.py            +21  -19
megatron/__init__.py

@@ -26,6 +26,8 @@ from .package_info import (
 )
 
 from .global_vars import get_args
+from .global_vars import get_num_microbatches
+from .global_vars import update_num_microbatches
 from .global_vars import get_tokenizer
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
megatron/arguments.py

@@ -54,18 +54,45 @@ def parse_args(extra_args_provider=None, defaults={},
     # Distributed args.
     args.rank = int(os.getenv('RANK', '0'))
     args.world_size = int(os.getenv("WORLD_SIZE", '1'))
     # Tensor model parallel size.
     args.tensor_model_parallel_size = min(
         args.tensor_model_parallel_size, args.world_size)
     assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
         ' ({}) is not divisible by tensor model parallel size ({})'.format(
             args.world_size, args.tensor_model_parallel_size)
     # Pipeline model parallel size.
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
     if args.pipeline_model_parallel_size > 1:
         if "ring_exchange" not in dir(torch.distributed):
-            raise Exception('PyTorch with torch.distributed.ring_exchange needed '
-                            'to run pipeline MP!')
+            raise Exception('PyTorch with torch.distributed.ring_exchange '
+                            'needed to run pipeline MP!')
     # Checks.
     args.model_parallel_size = args.pipeline_model_parallel_size * \
                                args.tensor_model_parallel_size
     assert args.world_size % args.model_parallel_size == 0, 'world size is not'\
         ' divisible by tensor parallel size ({}) times pipeline paralle '\
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // args.model_parallel_size
     if args.rank == 0:
-        print('using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '.format(
-            args.world_size, args.tensor_model_parallel_size, args.pipeline_model_parallel_size))
+        print('using world size: {}, data-parallel-size: {}, '
+              'tensor-model-parallel size: {}, '
+              'pipeline-model-parallel size: {} '.format(
+                  args.world_size, args.data_parallel_size,
+                  args.tensor_model_parallel_size,
+                  args.pipeline_model_parallel_size), flush=True)
+
+    # Batch size.
+    assert args.micro_batch_size is not None
+    assert args.micro_batch_size > 0
+    if args.global_batch_size is None:
+        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
+        if args.rank == 0:
+            print('setting global batch size to {}'.format(
+                args.global_batch_size), flush=True)
+    assert args.global_batch_size > 0
 
     # Fp16 loss scaling.
     args.dynamic_loss_scale = False

@@ -214,8 +241,22 @@ def _add_training_args(parser):
                        help='Batch size per model instance (local batch size). '
                        'Global batch size is local batch size times data '
                        'parallel size.')
-    group.add_argument('--num-microbatches', type=int, default=1,
-                       help='Number of microbatches in minibatch')
+    group.add_argument('--global-batch-size', type=int, default=None,
+                       help='Training batch size. If this value is None, then '
+                       'use micro-batch-size * data-parallel-size as the '
+                       'global batch size')
+    group.add_argument('--rampup-batch-size', nargs='*', default=None,
+                       help='Batch size ramp up with the following values:'
+                       ' --rampup-batch-size <start batch size> '
+                       '                     <batch size incerement> '
+                       '                     <ramp-up samples> '
+                       'For example:'
+                       ' --rampup-batch-size 16 8 300000 \ '
+                       ' --global-batch-size 1024 '
+                       'will start with global batch size 16 and over '
+                       ' (1024 - 16) / 8 = 126 intervals will increase'
+                       'the batch size linearly to 1024. In each interval'
+                       'we will use approximately 300000 / 126 = 2380 samples.')
     group.add_argument('--checkpoint-activations', action='store_true',
                        help='Checkpoint activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
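The help text above describes the intended ramp-up arithmetic. Below is an illustrative check of those numbers only: it is not code from this commit (the ramp-up schedule itself is not implemented here, and the helper names are made up).

# Illustrative sketch of the batch-size arithmetic in the help strings above.
# Helper names are hypothetical; only the constant case is wired up in this commit.

def default_global_batch_size(micro_batch_size, data_parallel_size):
    # Default applied by parse_args when --global-batch-size is not given.
    return micro_batch_size * data_parallel_size

def rampup_intervals(start, increment, ramp_samples, target):
    # Number of batch-size increments between start and target, and roughly how
    # many samples are consumed at each intermediate batch size.
    assert (target - start) % increment == 0
    intervals = (target - start) // increment
    return intervals, ramp_samples // intervals

if __name__ == '__main__':
    print(default_global_batch_size(micro_batch_size=4, data_parallel_size=32))  # 128
    # --rampup-batch-size 16 8 300000 --global-batch-size 1024
    print(rampup_intervals(16, 8, 300000, 1024))  # (126, 2380)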
megatron/checkpointing.py

@@ -23,7 +23,7 @@ import numpy as np
 import torch
 from torch.nn.parallel import DistributedDataParallel as torchDDP
 
-from megatron import mpu, get_args
+from megatron import mpu, get_args, update_num_microbatches
 from megatron import get_args
 from megatron import print_rank_0

@@ -236,6 +236,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         check_checkpoint_args(checkpoint_args)
         args.consumed_train_samples = getattr(checkpoint_args,
                                               'consumed_train_samples', 0)
+        update_num_microbatches(consumed_samples=args.consumed_train_samples)
         args.consumed_valid_samples = getattr(checkpoint_args,
                                               'consumed_valid_samples', 0)
     else:
megatron/data/data_loaders.py

@@ -30,13 +30,12 @@ def build_pretraining_data_loader(dataset, consumed_samples):
     args = get_args()
 
     world_size = mpu.get_data_parallel_world_size()
-    global_batch_size = args.micro_batch_size * world_size
 
     # Megatron sampler
     batch_sampler = MegatronPretrainingSampler(
         total_samples=len(dataset),
         consumed_samples=consumed_samples,
-        global_batch_size=global_batch_size,
+        global_batch_size=args.global_batch_size,
         rank=mpu.get_data_parallel_rank(),
         world_size=world_size)
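For orientation, the per-rank share of each global batch under the new argument works out as follows. This assumes MegatronPretrainingSampler splits every global batch evenly across data-parallel ranks, which is not shown in this diff; the numbers are made up.

# Rough per-rank arithmetic for the sampler change above (assumption: the sampler
# splits each global batch evenly across data-parallel ranks).
global_batch_size = 1024
micro_batch_size = 4
data_parallel_size = 32

samples_per_rank_per_step = global_batch_size // data_parallel_size     # 32
micro_batches_per_rank = samples_per_rank_per_step // micro_batch_size  # 8
print(samples_per_rank_per_step, micro_batches_per_rank)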
megatron/global_vars.py

@@ -15,6 +15,8 @@
 """Megatron global variables."""
 
+from abc import ABC
+from abc import abstractmethod
 import os
 import sys
 import time

@@ -25,18 +27,35 @@ from megatron.tokenizer import build_tokenizer
 from .arguments import parse_args
 
 _GLOBAL_ARGS = None
+_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
 _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 
 
 def get_args():
     """Return arguments."""
     _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
     return _GLOBAL_ARGS
 
 
+def get_num_microbatches_calculator():
+    """Return num-microbatches calculator."""
+    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
+                               'number of micro-batches calculator.')
+    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR
+
+
+def get_num_microbatches():
+    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
+
+
+def update_num_microbatches(consumed_samples):
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)
+
+
 def get_tokenizer():
     """Return tokenizer."""
     _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')

@@ -67,6 +86,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     args = _parse_args(extra_args_provider=extra_args_provider,
                        defaults=args_defaults,
                        ignore_unknown_args=ignore_unknown_args)
+    _build_num_microbatches_calculator(args)
     _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)

@@ -84,6 +104,62 @@ def _parse_args(extra_args_provider=None, defaults={},
     return _GLOBAL_ARGS
 
 
+def _build_num_microbatches_calculator(args):
+
+    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
+    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
+                                   'num microbatches calculator')
+
+    # Constant num micro-batches.
+    if args.rampup_batch_size is None:
+        micro_batch_times_data_parallel = args.micro_batch_size * \
+                                          args.data_parallel_size
+        assert args.global_batch_size % micro_batch_times_data_parallel == 0, \
+            'global batch size ({}) is not divisible by micro batch size ({})' \
+            ' times data parallel size ({})'.format(args.global_batch_size,
+                                                    args.micro_batch_size,
+                                                    args.data_parallel_size)
+        num_micro_batches = args.global_batch_size // \
+                            micro_batch_times_data_parallel
+        if args.rank == 0:
+            print('setting number of micro-batches to constant {}'.format(
+                num_micro_batches), flush=True)
+        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
+            num_micro_batches)
+        return
+
+    raise Exception('should not be here.')
+
+
+class NumMicroBatchesCalculator(ABC):
+
+    def __init__(self, name):
+        self.name = name
+        super(NumMicroBatchesCalculator, self).__init__()
+
+    @abstractmethod
+    def get(self):
+        pass
+
+    def update(self, consumed_samples):
+        pass
+
+
+class ConstantNumMicroBatches(NumMicroBatchesCalculator):
+
+    def __init__(self, num_micro_batches=1):
+        assert num_micro_batches >= 1
+        self.num_micro_batches = num_micro_batches
+        super(ConstantNumMicroBatches, self).__init__(
+            'constant: {}'.format(self.num_micro_batches))
+
+    def update(self, consumed_samples):
+        pass
+
+    def get(self):
+        return self.num_micro_batches
+
+
 def _build_tokenizer(args):
     """Initialize tokenizer."""
     global _GLOBAL_TOKENIZER
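A minimal standalone sketch of how the calculator globals above behave once built, using the constant case from this diff: build once, query with get(), and update() is a no-op. The driver values (global batch 512, micro-batch 2, data-parallel size 32) are made-up numbers, not defaults from the commit.

# Sketch of the calculator pattern added above; names here are local to the example.

_NUM_MICROBATCHES_CALCULATOR = None

class ConstantNumMicroBatches:
    def __init__(self, num_micro_batches=1):
        assert num_micro_batches >= 1
        self.num_micro_batches = num_micro_batches

    def update(self, consumed_samples):
        pass  # constant schedule: nothing changes as samples are consumed

    def get(self):
        return self.num_micro_batches

def build_calculator(global_batch_size, micro_batch_size, data_parallel_size):
    global _NUM_MICROBATCHES_CALCULATOR
    assert global_batch_size % (micro_batch_size * data_parallel_size) == 0
    _NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
        global_batch_size // (micro_batch_size * data_parallel_size))

build_calculator(global_batch_size=512, micro_batch_size=2, data_parallel_size=32)
print(_NUM_MICROBATCHES_CALCULATOR.get())  # 8 micro-batches per iteration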
megatron/training.py

@@ -25,6 +25,8 @@ from apex.optimizers import FusedAdam as Adam
 from megatron import get_args
 from megatron import get_timers
 from megatron import get_tensorboard_writer
+from megatron import get_num_microbatches
+from megatron import update_num_microbatches
 from megatron import mpu
 from megatron import print_rank_0
 from megatron import print_rank_last

@@ -137,10 +139,6 @@ def get_model(model_provider_func):
     if args.fp16:
         model = FP16_Module(model)
 
-    # Wrap model for distributed training."""
-    if args.num_microbatches > 1:
-        assert args.DDP_impl == 'local'
-
     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
         model = torchDDP(model, device_ids=[i], output_device=i,

@@ -225,6 +223,10 @@ def setup_model_and_optimizer(model_provider_func):
     else:
         args.iteration = 0
 
+    # Wrap model for distributed training."""
+    if get_num_microbatches() > 1:
+        assert args.DDP_impl == 'local'
+
     # get model without FP16 and/or TorchDDP wrappers
     unwrapped_model = model
     while hasattr(unwrapped_model, 'module'):

@@ -315,7 +317,7 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
-        output_tensor = loss / args.num_microbatches
+        output_tensor = loss / get_num_microbatches()
         losses_reduced.append(loss_reduced)
     else:
         timers('forward-send').start()

@@ -375,7 +377,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterator,
     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
-        output_tensor = loss / args.num_microbatches
+        output_tensor = loss / get_num_microbatches()
         output_tensor_grad = None
         losses_reduced.append(loss_reduced)
     else:

@@ -419,10 +421,10 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     args = get_args()
 
     losses_reduced = []
-    for i in range(args.num_microbatches):
+    for i in range(get_num_microbatches()):
         timers('forward-compute').start()
         loss, loss_reduced = forward_step_func(data_iterator, model,
                                                input_tensor=None)
-        output_tensor = loss / args.num_microbatches
+        output_tensor = loss / get_num_microbatches()
         losses_reduced.append(loss_reduced)
         timers('forward-compute').stop()

@@ -441,7 +443,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     args = get_args()
 
     # Compute number of warmup microbatches.
-    num_microbatches = args.num_microbatches
+    num_microbatches = get_num_microbatches()
     num_warmup_microbatches = \
         (mpu.get_pipeline_model_parallel_world_size() -
          mpu.get_pipeline_model_parallel_rank() - 1)

@@ -695,6 +697,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True
     while iteration < args.train_iters:
+        update_num_microbatches(args.consumed_train_samples)
         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
                                              model,

@@ -703,7 +706,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \
-                                       args.num_microbatches
+                                       get_num_microbatches()
 
         # Logging.
         loss_scale = None

@@ -761,7 +764,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                             args.eval_iters))
-            for _ in range(args.num_microbatches):
+            for _ in range(get_num_microbatches()):
                 if not mpu.is_pipeline_first_stage():
                     input_tensor, _ = communicate(
                         tensor_send_next=None,

@@ -789,12 +792,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                            * args.micro_batch_size \
-                                           * args.num_microbatches
+                                           * get_num_microbatches()
 
     # Move model back to the train mode.
     model.train()
 
     for key in total_loss_dict:
-        total_loss_dict[key] /= args.eval_iters * args.num_microbatches
+        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
 
     return total_loss_dict

@@ -834,13 +837,12 @@ def build_train_valid_test_data_iterators(
     # Rank and global batch size.
     data_parallel_size = mpu.get_data_parallel_world_size()
-    global_batch_size = args.micro_batch_size * data_parallel_size * args.num_microbatches
 
     # Backward compatibility, assume fixed batch size.
     if args.iteration > 0 and args.consumed_train_samples == 0:
-        args.consumed_train_samples = args.iteration * global_batch_size
+        args.consumed_train_samples = args.iteration * args.global_batch_size
     if args.iteration > 0 and args.consumed_valid_samples == 0:
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
-            args.eval_iters * global_batch_size
+            args.eval_iters * args.global_batch_size
 
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_tensor_model_parallel_rank() == 0:

@@ -849,9 +851,9 @@ def build_train_valid_test_data_iterators(
         train_iters = args.train_iters
         eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
         test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * global_batch_size,
-                                      eval_iters * global_batch_size,
-                                      test_iters * global_batch_size]
+        train_val_test_num_samples = [train_iters * args.global_batch_size,
+                                      eval_iters * args.global_batch_size,
+                                      test_iters * args.global_batch_size]
         print_rank_0(' > datasets target sizes (minimum size):')
         print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
         print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
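The training-loop changes above keep sample accounting in terms of the calculator: each iteration consumes data_parallel_size * micro_batch_size * num_microbatches samples, which equals the global batch size. A small illustrative check with made-up numbers:

# Illustrative sample accounting for the training-loop change above:
# one iteration consumes one global batch.
data_parallel_size = 32
micro_batch_size = 4
num_microbatches = 8

samples_per_iteration = data_parallel_size * micro_batch_size * num_microbatches
assert samples_per_iteration == 1024  # equals the global batch size in this setup

consumed_train_samples = 0
for iteration in range(10):
    consumed_train_samples += samples_per_iteration
print(consumed_train_samples)  # 10240 after 10 iterations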