OpenDAS / apex / Commits

Commit fa8bd7e6 (Unverified)
Authored Nov 09, 2021 by eqy; committed by GitHub on Nov 09, 2021
Parent: ae757634

check in (#1205)
Changes: 1 changed file, with 25 additions and 1 deletion.

apex/transformer/parallel_state.py (+25, -1) [view file @ fa8bd7e6]
@@ -40,6 +40,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None
@@ -93,6 +96,9 @@ def initialize_model_parallel(
    num_data_parallel_groups = world_size // data_parallel_size

    if virtual_pipeline_model_parallel_size_ is not None:
        assert pipeline_model_parallel_size_ > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
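
Context for the new assert (a hypothetical illustration, not code from this commit): the kwargs below use the trailing-underscore names that appear in this diff, and the call assumes torch.distributed is already initialized.

    # Hypothetical example (not part of this commit): requesting the
    # interleaved schedule with pipeline_model_parallel_size_ <= 2 now fails fast.
    from apex.transformer import parallel_state

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=1,
        pipeline_model_parallel_size_=2,          # not > 2 ...
        virtual_pipeline_model_parallel_size_=2,  # ... so the new assert raises:
    )
    # AssertionError: pipeline-model-parallel size should be greater than 2 with
    # interleaved schedule
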
@@ -138,6 +144,7 @@ def initialize_model_parallel(
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
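
As a worked example of the striding above (illustrative numbers only): with world_size = 8 and pipeline_model_parallel_size = 4, there are 8 // 4 = 2 pipeline groups, each taking every second rank.

    # Illustration only: reproduces the rank striding from the loop above.
    world_size = 8
    num_pipeline_model_parallel_groups = 8 // 4  # world_size // pipeline_model_parallel_size

    for i in range(num_pipeline_model_parallel_groups):
        print(list(range(i, world_size, num_pipeline_model_parallel_groups)))
    # [0, 2, 4, 6]  <- pipeline group 0
    # [1, 3, 5, 7]  <- pipeline group 1
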
@@ -154,7 +161,8 @@ def initialize_model_parallel(
        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks


def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
@@ -193,6 +201,22 @@ def get_embedding_group():
    return _EMBEDDING_GROUP


def is_rank_in_embedding_group(ignore_virtual=False):
    """Return true if current rank is in embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _EMBEDDING_GLOBAL_RANKS
    if ignore_virtual:
        return rank in _EMBEDDING_GLOBAL_RANKS
    if rank in _EMBEDDING_GLOBAL_RANKS:
        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
            return is_pipeline_first_stage(ignore_virtual=False)
        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
            return is_pipeline_last_stage(ignore_virtual=False)
        else:
            return True
    return False


def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size"""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
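
The new helper is the kind of check used to gate synchronization of tied input/output embeddings between the first and last pipeline stages. A hedged usage sketch (the maybe_allreduce_embedding_grad wrapper is hypothetical; only is_rank_in_embedding_group and get_embedding_group come from this file):

    # Hypothetical usage (not part of this commit): ranks that belong to the
    # embedding group for their current stage all-reduce the tied embedding
    # gradient across the first and last pipeline stages.
    import torch
    from apex.transformer import parallel_state

    def maybe_allreduce_embedding_grad(word_embedding_weight):
        if parallel_state.is_rank_in_embedding_group(ignore_virtual=False):
            torch.distributed.all_reduce(
                word_embedding_weight.grad,
                group=parallel_state.get_embedding_group(),
            )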