Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
feecd5d9
Commit
feecd5d9
authored
Dec 07, 2020
by
mohammad
Committed by
Deepak Narayanan
Dec 19, 2020
Browse files
Add constant num micro-batches calculator
parent
6ea23928
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
24 additions
and
33 deletions
+24
-33
megatron/arguments.py
megatron/arguments.py
+4
-4
megatron/data/data_loaders.py
megatron/data/data_loaders.py
+17
-25
megatron/global_vars.py
megatron/global_vars.py
+2
-1
megatron/mpu/random.py
megatron/mpu/random.py
+1
-1
megatron/training.py
megatron/training.py
+0
-2
No files found.
megatron/arguments.py
View file @
feecd5d9
...
...
@@ -69,13 +69,13 @@ def parse_args(extra_args_provider=None, defaults={},
raise
Exception
(
'PyTorch with torch.distributed.ring_exchange '
'needed to run pipeline MP!'
)
# Checks.
args
.
model_parallel_size
=
args
.
pipeline_model_parallel_size
*
\
args
.
tensor_model_parallel_size
assert
args
.
world_size
%
args
.
model_parallel_size
==
0
,
'world size is not'
\
model_parallel_size
=
args
.
pipeline_model_parallel_size
*
\
args
.
tensor_model_parallel_size
assert
args
.
world_size
%
model_parallel_size
==
0
,
'world size is not'
\
' divisible by tensor parallel size ({}) times pipeline paralle '
\
'size ({})'
.
format
(
args
.
world_size
,
args
.
tensor_model_parallel_size
,
args
.
pipeline_model_parallel_size
)
args
.
data_parallel_size
=
args
.
world_size
//
args
.
model_parallel_size
args
.
data_parallel_size
=
args
.
world_size
//
model_parallel_size
if
args
.
rank
==
0
:
print
(
'using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
...
...
megatron/data/data_loaders.py
View file @
feecd5d9
...
...
@@ -29,15 +29,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
return
None
args
=
get_args
()
world_size
=
mpu
.
get_data_parallel_world_size
()
# Megatron sampler
batch_sampler
=
MegatronPretrainingSampler
(
total_samples
=
len
(
dataset
),
consumed_samples
=
consumed_samples
,
global
_batch_size
=
args
.
global
_batch_size
,
rank
=
mpu
.
get_data_parallel_rank
(),
world_size
=
world_size
)
micro
_batch_size
=
args
.
micro
_batch_size
,
data_parallel_
rank
=
mpu
.
get_data_parallel_rank
(),
data_parallel_size
=
mpu
.
get_data_parallel_
world_size
()
)
# Torch dataloader.
return
torch
.
utils
.
data
.
DataLoader
(
dataset
,
...
...
@@ -49,13 +47,15 @@ def build_pretraining_data_loader(dataset, consumed_samples):
class
MegatronPretrainingSampler
:
def
__init__
(
self
,
total_samples
,
consumed_samples
,
global_batch_size
,
rank
,
world
_size
):
def
__init__
(
self
,
total_samples
,
consumed_samples
,
micro_batch_size
,
data_parallel_rank
,
data_parallel
_size
):
# Keep a copy of input params for later use.
self
.
total_samples
=
total_samples
self
.
consumed_samples
=
consumed_samples
self
.
global_batch_size
=
global_batch_size
self
.
rank
=
rank
self
.
micro_batch_size
=
micro_batch_size
self
.
data_parallel_rank
=
data_parallel_rank
self
.
micro_batch_times_data_parallel_size
=
self
.
micro_batch_size
*
\
data_parallel_size
# Sanity checks.
assert
self
.
total_samples
>
0
,
\
...
...
@@ -63,19 +63,11 @@ class MegatronPretrainingSampler:
assert
self
.
consumed_samples
<
self
.
total_samples
,
\
'no samples left to consume: {}, {}'
.
format
(
self
.
consumed_samples
,
self
.
total_samples
)
assert
self
.
global_batch_size
>
0
,
\
'Unexpected global batch size: {}'
.
format
(
self
.
global_batch_size
)
assert
world_size
>
0
,
\
'non zero world size is expected: {}'
.
format
(
world_size
)
assert
self
.
rank
<
world_size
,
\
'rank should be smaller than world size: {}, {}'
.
format
(
self
.
rank
,
world_size
)
# Batch size per rank.
assert
self
.
global_batch_size
%
world_size
==
0
,
\
'global batch size must be divisible by world size: {}, {}'
.
format
(
self
.
global_batch_size
,
world_size
)
self
.
batch_size_per_rank
=
self
.
global_batch_size
//
world_size
assert
self
.
micro_batch_size
>
0
assert
data_parallel_size
>
0
assert
self
.
data_parallel_rank
<
data_parallel_size
,
\
'data_parallel_rank should be smaller than data size: {}, '
\
'{}'
.
format
(
self
.
data_parallel_rank
,
data_parallel_size
)
def
__len__
(
self
):
...
...
@@ -87,8 +79,8 @@ class MegatronPretrainingSampler:
# Last batch if not complete will be dropped.
for
idx
in
range
(
self
.
consumed_samples
,
self
.
total_samples
):
batch
.
append
(
idx
)
if
len
(
batch
)
==
self
.
global_batch
_size
:
start_idx
=
self
.
rank
*
self
.
batch_size
_per_rank
end_idx
=
start_idx
+
self
.
batch_size
_per_rank
if
len
(
batch
)
==
self
.
micro_batch_times_data_parallel
_size
:
start_idx
=
self
.
data_parallel_
rank
*
self
.
micro_
batch_size
end_idx
=
start_idx
+
self
.
micro_
batch_size
yield
batch
[
start_idx
:
end_idx
]
batch
=
[]
megatron/global_vars.py
View file @
feecd5d9
...
...
@@ -113,7 +113,7 @@ def _build_num_microbatches_calculator(args):
# Constant num micro-batches.
if
args
.
rampup_batch_size
is
None
:
micro_batch_times_data_parallel
=
args
.
micro_batch_size
*
\
arg
.
data_parallel_size
arg
s
.
data_parallel_size
assert
args
.
global_batch_size
%
micro_batch_times_data_parallel
==
0
,
\
'global batch size ({}) is not divisible by micro batch size ({})'
\
' times data parallel size ({})'
.
format
(
args
.
global_batch_size
,
...
...
@@ -126,6 +126,7 @@ def _build_num_microbatches_calculator(args):
num_micro_batches
),
flush
=
True
)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR
=
ConstantNumMicroBatches
(
num_micro_batches
)
return
raise
Exception
(
'should not be here.'
)
...
...
megatron/mpu/random.py
View file @
feecd5d9
...
...
@@ -46,7 +46,7 @@ def init_checkpointed_activations_memory_buffer():
args
=
get_args
()
per_layer
=
args
.
micro_batch_size
*
args
.
max_position_embeddings
*
\
args
.
hidden_size
//
args
.
model_parallel_size
args
.
hidden_size
//
args
.
tensor_
model_parallel_size
assert
args
.
num_layers
%
args
.
checkpoint_num_layers
==
0
,
\
'number of layers is not divisible by checkpoint-num-layers'
num_checkpointer_layers
=
args
.
num_layers
//
args
.
checkpoint_num_layers
...
...
megatron/training.py
View file @
feecd5d9
...
...
@@ -835,8 +835,6 @@ def build_train_valid_test_data_iterators(
print_rank_0
(
'> building train, validation, and test datasets ...'
)
# Rank and global batch size.
data_parallel_size
=
mpu
.
get_data_parallel_world_size
()
# Backward compatibility, assume fixed batch size.
if
args
.
iteration
>
0
and
args
.
consumed_train_samples
==
0
:
args
.
consumed_train_samples
=
args
.
iteration
*
args
.
global_batch_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment