OpenDAS / Megatron-LM · Commits

Commit 0c077a2c
Authored May 12, 2020 by Neel Kant

Merge branch 'master' into realm-mlm

Parents: 150f2384, e9eef962

Showing 4 changed files with 78 additions and 16 deletions:

    megatron/arguments.py           +18   -0
    megatron/data/samplers.py       +15   -2
    megatron/initialize.py           +8   -4
    megatron/model/transformer.py   +37  -10
megatron/arguments.py
@@ -89,6 +89,14 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.min_lr <= args.lr
     if args.save is not None:
         assert args.save_interval is not None
+    # Parameter sharing does not work with torch DDP.
+    if (args.num_unique_layers is not None) and (args.num_layers is not None):
+        assert args.num_unique_layers <= args.num_layers
+        assert args.num_layers % args.num_unique_layers == 0, \
+            'num-layers should be divisible by num-unique-layers.'
+        if args.num_unique_layers < args.num_layers:
+            assert args.DDP_impl == 'local', \
+                'torch-DDP does not work with parameter sharing.'
 
     _print_args(args)
     return args
@@ -116,6 +124,16 @@ def _add_network_size_args(parser):
     group.add_argument('--num-layers', type=int, default=None,
                        help='Number of transformer layers.')
+    group.add_argument('--num-unique-layers', type=int, default=None,
+                       help='Number of unique transformer layers. '
+                       '`num-layers` should be divisible by this value.')
+    group.add_argument('--param-sharing-style', default='grouped',
+                       choices=['grouped', 'spaced'],
+                       help='Ordering of the shared parameters. For example, '
+                       'for a `num-layers`=4 and `--num-unique-layers`=2, '
+                       'we will have the following ordering for two unique '
+                       'layers 1 and 2: '
+                       '    grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Transformer hidden size.')
     group.add_argument('--num-attention-heads', type=int, default=None,
...
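Taken together, the two hunks define the new flags and then validate them. Below is a minimal, self-contained sketch of that interaction; the toy parser is an assumption made for illustration, not the real `_add_network_size_args` plumbing:

    import argparse

    # Hypothetical stripped-down parser mirroring the new flags above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-layers', type=int, default=None)
    parser.add_argument('--num-unique-layers', type=int, default=None)
    parser.add_argument('--param-sharing-style', default='grouped',
                        choices=['grouped', 'spaced'])
    parser.add_argument('--DDP-impl', default='local')
    args = parser.parse_args('--num-layers 4 --num-unique-layers 2'.split())

    # Same checks as the parse_args() hunk above: 2 <= 4, 4 % 2 == 0, and
    # since 2 < 4 (layers are actually shared), DDP_impl must be 'local'.
    if (args.num_unique_layers is not None) and (args.num_layers is not None):
        assert args.num_unique_layers <= args.num_layers
        assert args.num_layers % args.num_unique_layers == 0, \
            'num-layers should be divisible by num-unique-layers.'
        if args.num_unique_layers < args.num_layers:
            assert args.DDP_impl == 'local', \
                'torch-DDP does not work with parameter sharing.'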
megatron/data/samplers.py
@@ -80,10 +80,20 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
-    sampler."""
+    sampler.
+
+    The `interleave` argument specifies how to distribute a batch. A value
+    of True combined with the above random sampler is equivalent to pytorch's
+    torch.utils.data.distributed.DistributedSampler.
+
+    For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2,
+    specifying True will result in the following samples for each gpu:
+        GPU0: [0,2,4,6] GPU1: [1,3,5,7]
+    specifying False will result in the following samples:
+        GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""
 
     def __init__(self, sampler, batch_size, drop_last, rank=-1,
-                 world_size=2, wrap_last=False):
+                 world_size=2, wrap_last=False, interleave=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
                                                       drop_last)
         if rank == -1:

@@ -95,6 +105,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
         self.wrap_around = 0
         self.wrap_last = wrap_last
         self.start_iter = 0
+        self.interleave = interleave
 
     def __iter__(self):
         batch = []

@@ -130,6 +141,8 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     def _batch(self, batch):
         """extracts samples only pertaining to this worker's batch"""
+        if self.interleave:
+            return batch[self.rank:self.batch_size:self.world_size]
         start = self.rank * self.batch_size // self.world_size
         end = (self.rank + 1) * self.batch_size // self.world_size
         return batch[start:end]
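The slicing in `_batch` reproduces exactly the GPU0/GPU1 split described in the new docstring; a self-contained illustration in plain Python, with none of the sampler machinery:

    # interleave=True is a strided slice (every world_size-th sample);
    # interleave=False is a contiguous block per rank.
    batch = [0, 1, 2, 3, 4, 5, 6, 7]
    batch_size, world_size = len(batch), 2

    for rank in range(world_size):
        interleaved = batch[rank:batch_size:world_size]
        start = rank * batch_size // world_size
        end = (rank + 1) * batch_size // world_size
        blocked = batch[start:end]
        print('GPU{}: interleaved={} blocked={}'.format(
            rank, interleaved, blocked))
    # GPU0: interleaved=[0, 2, 4, 6] blocked=[0, 1, 2, 3]
    # GPU1: interleaved=[1, 3, 5, 7] blocked=[4, 5, 6, 7]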
megatron/initialize.py
@@ -29,11 +29,15 @@ from megatron.global_vars import set_global_variables
 def initialize_megatron(extra_args_provider=None, args_defaults={},
-                        ignore_unknown_args=False):
+                        ignore_unknown_args=False, allow_no_cuda=False):
     """Set global variables, initialize distributed, and
-    set autoresume and random seeds."""
-
-    # Make sure cuda is available.
-    assert torch.cuda.is_available(), 'Megatron requires CUDA.'
+    set autoresume and random seeds.
+    `allow_no_cuda` should not be set unless using megatron for cpu only
+    data processing. In general this arg should not be set unless you know
+    what you are doing."""
+
+    if not allow_no_cuda:
+        # Make sure cuda is available.
+        assert torch.cuda.is_available(), 'Megatron requires CUDA.'
 
     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
...
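A sketch of how the new keyword would be used, e.g. from a hypothetical CPU-only data preprocessing script; the call site is an assumption, only the `allow_no_cuda` keyword comes from the diff above:

    from megatron.initialize import initialize_megatron

    # Skips only the torch.cuda.is_available() assert guarded above; argument
    # parsing, global variables, autoresume, and seeds are set up as usual.
    initialize_megatron(allow_no_cuda=True)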
megatron/model/transformer.py
@@ -360,34 +360,60 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_activations = args.checkpoint_activations
         self.checkpoint_num_layers = args.checkpoint_num_layers
 
-        def get_layer(layer_number):
+        # Number of layers:
+        self.num_layers = args.num_layers
+        self.num_unique_layers = args.num_unique_layers
+        if self.num_unique_layers is None:
+            self.num_unique_layers = self.num_layers
+        assert self.num_layers % self.num_unique_layers == 0, \
+            'number of layers should be divisible by number of unique layers'
+        self.param_sharing_style = args.param_sharing_style
+
+        # Transformer layers.
+        def build_layer(layer_number):
             return ParallelTransformerLayer(
                 attention_mask_func, mlp_activation_func,
                 init_method, output_layer_init_method, layer_number)
 
-        # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer(i + 1) for i in range(args.num_layers)])
+            [build_layer(i + 1) for i in range(self.num_unique_layers)])
+
+        # Print layer ordering.
+        if self.num_layers != self.num_unique_layers:
+            if torch.distributed.get_rank() == 0:
+                print('> will be using the following layer ordering:')
+                for i in range(self.num_layers):
+                    print('   layer id: {:3d} --> unique layer id: '
+                          '{:3d}'.format(i, self._get_layer_index(i)),
+                          flush=True)
 
         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)
 
+    def _get_layer_index(self, layer_number):
+        if self.param_sharing_style == 'grouped':
+            return layer_number % self.num_unique_layers
+        if self.param_sharing_style == 'spaced':
+            return layer_number // (self.num_layers // self.num_unique_layers)
+        assert False, 'should not be here'
+
+    def _get_layer(self, layer_number):
+        return self.layers[self._get_layer_index(layer_number)]
 
     def _checkpointed_forward(self, hidden_states, attention_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):
             def custom_forward(*inputs):
-                layers_ = self.layers[start:end]
                 x_ = inputs[0]
-                for layer in layers_:
+                for index in range(start, end):
+                    layer = self._get_layer(index)
                     x_ = layer(x_, inputs[1])
                 return x_
             return custom_forward
 
         l = 0
-        num_layers = len(self.layers)
-        while l < num_layers:
+        while l < self.num_layers:
             hidden_states = mpu.checkpoint(
                 custom(l, l + self.checkpoint_num_layers),
                 hidden_states, attention_mask)
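A quick standalone check of the `_get_layer_index` arithmetic for `num_layers=4` and `num_unique_layers=2`, reproducing the ordering documented in the `--param-sharing-style` help text; the free function below mirrors the method purely for illustration:

    num_layers, num_unique_layers = 4, 2

    def get_layer_index(layer_number, style):
        # Same arithmetic as ParallelTransformer._get_layer_index.
        if style == 'grouped':
            return layer_number % num_unique_layers
        if style == 'spaced':
            return layer_number // (num_layers // num_unique_layers)
        assert False, 'should not be here'

    for style in ('grouped', 'spaced'):
        order = [get_layer_index(i, style) + 1 for i in range(num_layers)]
        print(style, order)
    # grouped [1, 2, 1, 2]
    # spaced  [1, 1, 2, 2]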
@@ -414,10 +440,11 @@ class ParallelTransformer(MegatronModule):
         else:
             if get_key_value:
                 presents = []
-            for i, layer in enumerate(self.layers):
+            for index in range(self.num_layers):
+                layer = self._get_layer(index)
                 past = None
                 if layer_past is not None:
-                    past = layer_past[i]
+                    past = layer_past[index]
                 hidden_states = layer(hidden_states,
                                       attention_mask,
                                       layer_past=past,
...
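The forward loops above realize parameter sharing simply by fetching each logical layer through `_get_layer`, so two logical layers that map to the same unique index are literally the same module. A minimal sketch of that mechanism, assuming plain `torch.nn.Linear` layers in place of `ParallelTransformerLayer`:

    import torch

    num_layers, num_unique_layers = 4, 2
    # Only the unique layers are materialized.
    layers = torch.nn.ModuleList(
        [torch.nn.Linear(8, 8) for _ in range(num_unique_layers)])

    def get_layer(layer_number):
        # 'grouped' style: unique ids cycle as [1, 2, 1, 2].
        return layers[layer_number % num_unique_layers]

    assert get_layer(0) is get_layer(2)  # same object, shared weights
    x = torch.randn(1, 8)
    for index in range(num_layers):
        x = get_layer(index)(x)  # 4 logical layers, 2 unique weight sets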