OpenDAS / Megatron-LM

Commit 17843605, authored Dec 17, 2021 by Vijay Korthikanti

pipeline_fixes

Parent: 945ece94
Showing 5 changed files with 63 additions and 10 deletions (+63, -10):
megatron/model/language_model.py   +0  -4
megatron/model/transformer.py      +10 -1
megatron/mpu/__init__.py           +2  -0
megatron/mpu/initialize.py         +42 -5
megatron/training.py               +9  -0
megatron/model/language_model.py

...
@@ -334,10 +334,6 @@ class TransformerLanguageModel(MegatronModule):
         # Decoder (usually set to False, True if part of an encoder-decoder
         # architecture and in decoder-only stage).
         if self.add_decoder:
-            # Temporary assertion until we verify correctness of pipeline parallelism
-            # implementation of T5.
-            assert args.pipeline_model_parallel_size == 1, \
-                'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,
...
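Context for the change above (not part of the commit): removing the temporary assertion allows a T5-style encoder-decoder model to be built with pipeline_model_parallel_size > 1. A minimal standalone sketch, assuming a hypothetical split_rank that plays the role of args.pipeline_model_parallel_split_rank:

def stage_builds_decoder(pipeline_rank, split_rank):
    # Stages before the split hold encoder layers; stages at or after it hold decoder layers.
    return pipeline_rank >= split_rank


if __name__ == "__main__":
    # Example: 4 pipeline stages with the split at rank 2.
    for rank in range(4):
        role = "decoder" if stage_builds_decoder(rank, 2) else "encoder"
        print(f"pipeline rank {rank}: {role}")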
megatron/model/transformer.py

...
@@ -580,6 +580,7 @@ class ParallelTransformer(MegatronModule):
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
                 'virtual_pipeline_model_parallel_size'
+            assert args.model_type != ModelType.encoder_and_decoder
             # Number of layers in each model chunk is the number of layers in the stage,
             # divided by the number of model chunks in a stage.
             self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
...
@@ -596,7 +597,15 @@ class ParallelTransformer(MegatronModule):
                 (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
-            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+            if args.model_type == ModelType.encoder_and_decoder:
+                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
+                if layer_type == LayerType.encoder:
+                    offset = pipeline_rank * self.num_layers
+                else:
+                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
+                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
+            else:
+                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
...
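Not part of the commit: a standalone sketch of the new offset computation above. Decoder stages subtract the number of encoder pipeline ranks so their global layer numbering restarts at the split, and each stage still builds a contiguous block of layers. The names num_layers_per_stage and num_ranks_in_enc are local stand-ins for the Megatron arguments.

def layer_offset(pipeline_rank, num_layers_per_stage,
                 is_encoder_and_decoder, is_encoder_layer, num_ranks_in_enc):
    # Mirrors the branch added above.
    if is_encoder_and_decoder:
        if is_encoder_layer:
            return pipeline_rank * num_layers_per_stage
        return (pipeline_rank - num_ranks_in_enc) * num_layers_per_stage
    return pipeline_rank * num_layers_per_stage


if __name__ == "__main__":
    # Example: 4 pipeline stages, 2 encoder ranks, 3 layers per stage.
    for rank in range(4):
        is_enc = rank < 2
        offset = layer_offset(rank, 3, True, is_enc, 2)
        layers = [i + 1 + offset for i in range(3)]
        print(f"rank {rank} ({'encoder' if is_enc else 'decoder'}): layers {layers}")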
megatron/mpu/__init__.py

...
@@ -25,6 +25,7 @@ from .initialize import get_data_parallel_group
 from .initialize import get_data_parallel_rank
 from .initialize import get_data_parallel_world_size
 from .initialize import get_embedding_group
+from .initialize import get_position_embedding_group
 from .initialize import get_model_parallel_group
 from .initialize import get_tensor_model_parallel_group
 from .initialize import get_pipeline_model_parallel_group
...
@@ -32,6 +33,7 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_paralle
 from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
 from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
 from .initialize import is_rank_in_embedding_group
+from .initialize import is_rank_in_position_embedding_group
 from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
 from .initialize import is_pipeline_stage_at_split
 from .initialize import get_num_layers
...
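Not part of the commit: a hedged usage sketch of the two newly exported helpers, assuming torch.distributed and mpu.initialize_model_parallel() have already been set up on every rank.

from megatron import mpu

# Only ranks that hold a replica of the position embedding belong to this group.
if mpu.is_rank_in_position_embedding_group():
    pos_emb_group = mpu.get_position_embedding_group()
    # pos_emb_group can then be passed to torch.distributed collectives.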
megatron/mpu/initialize.py

...
@@ -29,6 +29,8 @@ _PIPELINE_MODEL_PARALLEL_GROUP = None
 _MODEL_PARALLEL_GROUP = None
 # Embedding group.
 _EMBEDDING_GROUP = None
+# Position embedding group.
+_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
...
@@ -45,6 +47,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
 # A list of ranks that have a copy of the embedding.
 _EMBEDDING_GLOBAL_RANKS = None
+# A list of ranks that have a copy of the position embedding.
+_POSITION_EMBEDDING_GLOBAL_RANKS = None
 # A list of global ranks for each pipeline group to ease calculation of the source
 # rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
...
@@ -165,6 +170,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     global _EMBEDDING_GLOBAL_RANKS
     assert _EMBEDDING_GROUP is None, \
         'embedding group is already initialized'
+    global _POSITION_EMBEDDING_GROUP
+    global _POSITION_EMBEDDING_GLOBAL_RANKS
+    assert _POSITION_EMBEDDING_GROUP is None, \
+        'position embedding group is already initialized'
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size,
                       num_pipeline_model_parallel_groups)
...
@@ -176,19 +185,31 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         # first and last stages).
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
-            if pipeline_model_parallel_split_rank_ is not None and \
-                    ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
-                embedding_ranks = [ranks[0],
-                                   ranks[pipeline_model_parallel_split_rank_],
-                                   ranks[-1]]
+            position_embedding_ranks = [ranks[0]]
+            if pipeline_model_parallel_split_rank_ is not None:
+                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+                    embedding_ranks = [ranks[0],
+                                       ranks[pipeline_model_parallel_split_rank_],
+                                       ranks[-1]]
+                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
+                    position_embedding_ranks = [ranks[0],
+                                                ranks[pipeline_model_parallel_split_rank_]]
         else:
             embedding_ranks = ranks
+            position_embedding_ranks = ranks

         group = torch.distributed.new_group(embedding_ranks)
         if rank in embedding_ranks:
             _EMBEDDING_GROUP = group
         if rank in ranks:
             _EMBEDDING_GLOBAL_RANKS = embedding_ranks

+        group = torch.distributed.new_group(position_embedding_ranks)
+        if rank in position_embedding_ranks:
+            _POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _POSITION_EMBEDDING_GLOBAL_RANKS = embedding_ranks
+

 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
...
@@ -234,6 +255,13 @@ def get_embedding_group():
     return _EMBEDDING_GROUP


+def get_position_embedding_group():
+    """Get the position embedding group the caller rank belongs to."""
+    assert _POSITION_EMBEDDING_GROUP is not None, \
+        'position embedding group is not initialized'
+    return _POSITION_EMBEDDING_GROUP
+
+
 def set_tensor_model_parallel_world_size(world_size):
     """Set the tensor model parallel size"""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
...
@@ -352,6 +380,13 @@ def is_rank_in_embedding_group(ignore_virtual=False):
     return False


+def is_rank_in_position_embedding_group():
+    """Return true if current rank is in position embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
+
+
 def is_pipeline_stage_before_split(rank=None):
     """Return True if pipeline stage executes encoder block for a model
     with both encoder and decoder."""
...
@@ -467,3 +502,5 @@ def destroy_model_parallel():
     _DATA_PARALLEL_GROUP = None
     global _EMBEDDING_GROUP
     _EMBEDDING_GROUP = None
+    global _POSITION_EMBEDDING_GROUP
+    _POSITION_EMBEDDING_GROUP = None
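Not part of the commit: a standalone sketch of the rank-selection logic added to initialize_model_parallel above. For each pipeline group, the embedding group contains the first, split, and last stage ranks, while the position-embedding group contains only the first and split stage ranks.

def embedding_groups(ranks, split_rank=None):
    # Mirrors the selection above for a single pipeline group.
    if len(ranks) > 1:
        embedding_ranks = [ranks[0], ranks[-1]]
        position_embedding_ranks = [ranks[0]]
        if split_rank is not None:
            if ranks[split_rank] not in embedding_ranks:
                embedding_ranks = [ranks[0], ranks[split_rank], ranks[-1]]
            if ranks[split_rank] not in position_embedding_ranks:
                position_embedding_ranks = [ranks[0], ranks[split_rank]]
    else:
        embedding_ranks = list(ranks)
        position_embedding_ranks = list(ranks)
    return embedding_ranks, position_embedding_ranks


if __name__ == "__main__":
    # One pipeline group of 8 ranks, encoder/decoder split at stage 3.
    print(embedding_groups(list(range(8)), split_rank=3))
    # -> ([0, 3, 7], [0, 3])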
megatron/training.py

...
@@ -427,6 +427,15 @@ def train_step(forward_step_func, data_iterator,
             else:
                 grad = word_embeddings_weight.grad
             torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
+
+    if mpu.is_rank_in_position_embedding_group() and \
+            mpu.get_pipeline_model_parallel_world_size() > 1 and \
+            args.pipeline_model_parallel_split_rank is not None:
+        unwrapped_model = model[0]
+        unwrapped_model = unwrap_model(
+            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
+        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
     timers('backward-embedding-all-reduce').stop()

     # Update parameters.
...
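Not part of the commit: a hedged sketch of the all-reduce pattern the new block performs. Ranks that hold a replica of the position embedding sum their gradients so the copies stay in sync. The sketch uses .grad rather than Megatron's .main_grad buffer and takes a pre-built process group; it illustrates the pattern, not the training-loop code itself.

import torch.distributed as dist


def allreduce_position_embedding_grad(position_embedding_weight, pos_emb_group):
    # Sum the position-embedding gradient across the ranks that replicate it
    # (e.g. the first encoder stage and the first decoder stage).
    # Assumes torch.distributed is initialized and pos_emb_group was created once
    # at startup via dist.new_group(...) on all ranks.
    grad = position_embedding_weight.grad
    if grad is not None:
        dist.all_reduce(grad, group=pos_emb_group)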