OpenDAS / Megatron-LM · Commits

Commit 9ad1944d authored May 26, 2022 by Vijay Korthikanti

bug fix for get_data_parallel_src_rank

parent 739cb43d
Showing 1 changed file with 12 additions and 5 deletions.

megatron/mpu/initialize.py  (+12, -5)
@@ -54,6 +54,12 @@ _POSITION_EMBEDDING_GLOBAL_RANKS = None
 # rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
 
+# A list of global ranks for each data parallel group to ease calculation of the source
+# rank when broadcasting weights from src to all other data parallel ranks
+_DATA_PARALLEL_GLOBAL_RANKS = None
+
 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
     return _DATA_PARALLEL_GROUP is None
@@ -124,6 +130,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
+    global _DATA_PARALLEL_GLOBAL_RANKS
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group is already initialized'
     all_data_parallel_group_ranks = []
@@ -137,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
             group = torch.distributed.new_group(ranks)
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
+                _DATA_PARALLEL_GLOBAL_RANKS = ranks
 
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
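Note: the loop that produces `ranks` sits outside this hunk. As a rough illustration of what ends up in _DATA_PARALLEL_GLOBAL_RANKS, the sketch below re-creates the enumeration for a hypothetical 8-GPU job with tensor-parallel size 2 and pipeline-parallel size 2, assuming the rank layout described in initialize_model_parallel's docstring (tensor-parallel ranks adjacent, one contiguous block of ranks per pipeline stage). The configuration values are illustrative, not part of this commit.

# Sketch: enumerate data-parallel groups under the assumed layout described above
# (illustrative values, not code from the commit).
world_size, tp, pp = 8, 2, 2            # 8 GPUs, tensor-parallel 2, pipeline-parallel 2
ranks_per_stage = world_size // pp      # contiguous block of ranks per pipeline stage

all_data_parallel_group_ranks = []
for stage in range(pp):                              # one block of ranks per pipeline stage
    start_rank = stage * ranks_per_stage
    end_rank = (stage + 1) * ranks_per_stage
    for j in range(tp):                              # one data-parallel group per tensor rank
        ranks = list(range(start_rank + j, end_rank, tp))
        all_data_parallel_group_ranks.append(ranks)

print(all_data_parallel_group_ranks)
# [[0, 2], [1, 3], [4, 6], [5, 7]]
# After this commit, each rank also caches its own group's list in
# _DATA_PARALLEL_GLOBAL_RANKS, e.g. rank 6 holds [4, 6].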
@@ -478,11 +486,10 @@ def get_tensor_model_parallel_src_rank():
 
 def get_data_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
-    in the tensor model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    data_parallel_size = get_data_parallel_world_size()
-    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
-    return global_rank % num_data_parallel_groups
+    in the data parallel group."""
+    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
+        "Data parallel group is not initialized"
+    return _DATA_PARALLEL_GLOBAL_RANKS[0]
 
 def get_pipeline_model_parallel_first_rank():
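Why the modulo formula was wrong: with pipeline parallelism, the number of data-parallel groups (tensor-parallel size × pipeline-parallel size) no longer equals the stride between members of one group (tensor-parallel size), so global_rank % num_data_parallel_groups usually points at a rank outside the caller's own data-parallel group. The sketch below checks this against the illustrative 8-rank layout from the note above; it is a comparison under the same assumed configuration, not code from this commit.

# Sketch: old vs. new source-rank computation for the illustrative 8-rank layout above.
world_size, tp, pp = 8, 2, 2
data_parallel_groups = [[0, 2], [1, 3], [4, 6], [5, 7]]

data_parallel_size = world_size // (tp * pp)                  # 2
num_data_parallel_groups = world_size // data_parallel_size   # 4

for group in data_parallel_groups:
    for rank in group:
        old_src = rank % num_data_parallel_groups   # pre-fix formula
        new_src = group[0]                          # _DATA_PARALLEL_GLOBAL_RANKS[0]
        print(f"rank {rank}: old src {old_src}, new src {new_src}")

# Ranks on pipeline stage 1 (4-7) and the non-leading data-parallel ranks on stage 0
# disagree: e.g. rank 6 gets old src 2, but the first rank of its group [4, 6] is 4.

Returning the cached _DATA_PARALLEL_GLOBAL_RANKS[0] sidesteps the arithmetic entirely, which is why the commit stores each group's full rank list at initialization time.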