OpenDAS / apex · Commits · f10b4b89

Unverified commit f10b4b89, authored Mar 24, 2022 by Masaki Kozuki; committed by GitHub on Mar 24, 2022.
[transformer] `parallel_state`: Position Embedding (#1343)

* update
* Add comment to `destroy_model_parallel`
Parent: 28f8539c

Changes: 1 changed file, with 155 additions and 41 deletions.

File: apex/transformer/parallel_state.py (+155, -41)
@@ -12,12 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# TODO (mkozuki): Replace assert with RuntimeError.
+# TODO (mkozuki): Sort the functions in the same order of megatron/mpu/initialize.py
 """Model and data parallel groups."""
 from typing import Tuple, Optional
 
 import torch
 
-from apex.transformer.utils import ensure_divisibility
+from apex.transformer.log_util import get_transformer_logger
+
+
+_logger = get_transformer_logger(__name__)
+
+# set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) =
+# {
+#     'get_num_layers',
+# }
 # Intra-layer model parallel group that the current rank belongs to.
@@ -28,6 +38,8 @@ _PIPELINE_MODEL_PARALLEL_GROUP = None
 _MODEL_PARALLEL_GROUP = None
 # Embedding group.
 _EMBEDDING_GROUP = None
+# Position embedding group.
+_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
@@ -44,6 +56,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
 # A list of ranks that have a copy of the embedding.
 _EMBEDDING_GLOBAL_RANKS = None
+# A list of ranks that have a copy of the position embedding.
+_POSITION_EMBEDDING_GLOBAL_RANKS = None
+
 # A list of global ranks for each pipeline group to ease calculation of the source
 # rank when broadcasting from the first or last pipeline stage
 _PIPELINE_GLOBAL_RANKS = None
@@ -87,28 +102,46 @@ def initialize_model_parallel(
     """
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
-    world_size = torch.distributed.get_world_size()
-    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
-    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
-    ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
-    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
+    world_size: int = torch.distributed.get_world_size()
+    tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size)
+    pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size)
+    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
+        raise RuntimeError(
+            f"`world_size` ({world_size}) is not divisible by tensor_model_parallel_size "
+            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
+        )
+    data_parallel_size: int = world_size // (
+        tensor_model_parallel_size * pipeline_model_parallel_size
+    )
     if torch.distributed.get_rank() == 0:
-        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
-        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
-        print("> initializing data parallel with size {}".format(data_parallel_size))
+        _logger.info(
+            "> initializing tensor model parallel with size {}".format(
+                tensor_model_parallel_size
+            )
+        )
+        _logger.info(
+            "> initializing pipeline model parallel with size {}".format(
+                pipeline_model_parallel_size
+            )
+        )
+        _logger.info(
+            "> initializing data parallel with size {}".format(data_parallel_size)
+        )
-    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
-    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
-    num_data_parallel_groups = world_size // data_parallel_size
+    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
+    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
+    num_data_parallel_groups: int = world_size // data_parallel_size
 
     if virtual_pipeline_model_parallel_size_ is not None:
-        assert pipeline_model_parallel_size_ > 2, \
-            'pipeline-model-parallel size should be greater than 2 with ' \
-            'interleaved schedule'
+        # assert pipeline_model_parallel_size_ > 2, (
+        #     "pipeline-model-parallel size should be greater than 2 with "
+        #     "interleaved schedule"
+        # )
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
-        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
+            virtual_pipeline_model_parallel_size_
+        )
 
     if pipeline_model_parallel_split_rank_ is not None:
         global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
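
As a worked illustration of the factorization rule the new explicit RuntimeError enforces (a standalone sketch, not part of the diff; the helper name is hypothetical):

```python
# Standalone sketch of the divisibility rule checked above:
# world_size must factor as tp * pp * dp.
def data_parallel_size_for(world_size: int, tp: int, pp: int) -> int:
    if world_size % (tp * pp) != 0:
        raise RuntimeError(
            f"`world_size` ({world_size}) is not divisible by "
            f"tensor_model_parallel_size ({tp}) x pipeline_model_parallel_size ({pp})"
        )
    return world_size // (tp * pp)

print(data_parallel_size_for(16, tp=2, pp=4))  # 2: 16 ranks = 2 (tp) x 4 (pp) x 2 (dp)
# data_parallel_size_for(16, tp=3, pp=2) raises RuntimeError: 16 % 6 != 0
```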
@@ -134,16 +167,23 @@ def initialize_model_parallel(
     global _MODEL_PARALLEL_GROUP
     assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
     for i in range(data_parallel_size):
-        ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]
+        ranks = [
+            data_parallel_group_ranks[i]
+            for data_parallel_group_ranks in all_data_parallel_group_ranks
+        ]
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _MODEL_PARALLEL_GROUP = group
 
     # Build the tensor model-parallel groups.
     global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
+    assert (
+        _TENSOR_MODEL_PARALLEL_GROUP is None
+    ), "tensor model parallel group is already initialized"
     for i in range(num_tensor_model_parallel_groups):
-        ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
+        ranks = list(
+            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
+        )
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
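
For readers following the group construction, here is a standalone sketch (not from the commit) of how these loops partition the global ranks, using an assumed 8-rank layout with tensor- and pipeline-parallel sizes of 2:

```python
# Rank partitioning mirroring the range arithmetic in the hunks above.
world_size, tensor_model_parallel_size, pipeline_model_parallel_size = 8, 2, 2
num_tensor_groups = world_size // tensor_model_parallel_size      # 4
num_pipeline_groups = world_size // pipeline_model_parallel_size  # 4

# Tensor-parallel groups are contiguous blocks of ranks.
tensor_groups = [
    list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))
    for i in range(num_tensor_groups)
]
# Pipeline-parallel groups are strided across the whole world.
pipeline_groups = [
    list(range(i, world_size, num_pipeline_groups)) for i in range(num_pipeline_groups)
]
print(tensor_groups)    # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pipeline_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```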
@@ -152,10 +192,17 @@ def initialize_model_parallel(
     # (first and last rank in each pipeline model-parallel group).
     global _PIPELINE_MODEL_PARALLEL_GROUP
     global _PIPELINE_GLOBAL_RANKS
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
+    assert (
+        _PIPELINE_MODEL_PARALLEL_GROUP is None
+    ), "pipeline model parallel group is already initialized"
     global _EMBEDDING_GROUP
     global _EMBEDDING_GLOBAL_RANKS
     assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
+    global _POSITION_EMBEDDING_GROUP
+    global _POSITION_EMBEDDING_GLOBAL_RANKS
+    assert (
+        _POSITION_EMBEDDING_GROUP is None
+    ), "position embedding group is already initialized"
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
         group = torch.distributed.new_group(ranks)
@@ -167,22 +214,38 @@ def initialize_model_parallel(
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
-            if (
-                pipeline_model_parallel_split_rank_ is not None
-                and ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks
-            ):
-                embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank_], ranks[-1]]
+            position_embedding_ranks = [ranks[0]]
+            if pipeline_model_parallel_split_rank_ is not None:
+                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+                    embedding_ranks = [
+                        ranks[0],
+                        ranks[pipeline_model_parallel_split_rank_],
+                        ranks[-1],
+                    ]
+                if (
+                    ranks[pipeline_model_parallel_split_rank_]
+                    not in position_embedding_ranks
+                ):
+                    position_embedding_ranks = [
+                        ranks[0],
+                        ranks[pipeline_model_parallel_split_rank_],
+                    ]
         else:
             embedding_ranks = ranks
+            position_embedding_ranks = ranks
 
         group = torch.distributed.new_group(embedding_ranks)
         if rank in embedding_ranks:
             _EMBEDDING_GROUP = group
         if rank in ranks:
             _EMBEDDING_GLOBAL_RANKS = embedding_ranks
 
+        group = torch.distributed.new_group(position_embedding_ranks)
+        if rank in position_embedding_ranks:
+            _POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
 
 
 def get_rank_info() -> Tuple[int, int, int]:
     """Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger."""
     if model_parallel_is_initialized():
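
To make the new rank selection concrete, a self-contained sketch (assumed values, not part of the diff) that mirrors the logic above for one pipeline group: the first and last stages share the word embedding, while only the first stage, plus the split stage when one exists, holds the position embedding.

```python
# One pipeline group of size 4 with an assumed encoder/decoder split at index 2.
ranks = [0, 4, 8, 12]
split = 2  # plays the role of pipeline_model_parallel_split_rank_

embedding_ranks = [ranks[0], ranks[-1]]  # [0, 12]
position_embedding_ranks = [ranks[0]]    # [0]
if split is not None:
    if ranks[split] not in embedding_ranks:
        embedding_ranks = [ranks[0], ranks[split], ranks[-1]]       # [0, 8, 12]
    if ranks[split] not in position_embedding_ranks:
        position_embedding_ranks = [ranks[0], ranks[split]]         # [0, 8]
print(embedding_ranks, position_embedding_ranks)  # [0, 8, 12] [0, 8]
```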
@@ -197,7 +260,11 @@ def get_rank_info() -> Tuple[int, int, int]:
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
-    if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
+    if (
+        _TENSOR_MODEL_PARALLEL_GROUP is None
+        or _PIPELINE_MODEL_PARALLEL_GROUP is None
+        or _DATA_PARALLEL_GROUP is None
+    ):
         return False
     return True
@@ -210,13 +277,17 @@ def get_model_parallel_group():
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "intra_layer_model parallel group is not initialized"
+    assert (
+        _TENSOR_MODEL_PARALLEL_GROUP is not None
+    ), "intra_layer_model parallel group is not initialized"
     return _TENSOR_MODEL_PARALLEL_GROUP
 
 
 def get_pipeline_model_parallel_group():
     """Get the pipeline model parallel group the caller rank belongs to."""
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized"
+    assert (
+        _PIPELINE_MODEL_PARALLEL_GROUP is not None
+    ), "pipeline_model parallel group is not initialized"
     return _PIPELINE_MODEL_PARALLEL_GROUP
@@ -232,6 +303,14 @@ def get_embedding_group():
     return _EMBEDDING_GROUP
 
 
+def get_position_embedding_group():
+    """Get the position embedding group the caller rank belongs to."""
+    assert (
+        _POSITION_EMBEDDING_GROUP is not None
+    ), "position embedding group is not initialized"
+    return _POSITION_EMBEDDING_GROUP
+
+
 def is_rank_in_embedding_group(ignore_virtual=False):
     """Return true if current rank is in embedding group, False otherwise."""
     rank = torch.distributed.get_rank()
@@ -248,6 +327,13 @@ def is_rank_in_embedding_group(ignore_virtual=False):
     return False
 
 
+def is_rank_in_position_embedding_group():
+    """Return whether the current rank is in position embedding group."""
+    rank = torch.distributed.get_rank()
+    global _POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
+
+
 def is_pipeline_stage_before_split(rank=None):
     """Return True if pipeline stage executes encoder block for a model
     with both encoder and decoder."""
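
A hypothetical usage sketch of the two new position-embedding helpers: everything except the `parallel_state` API itself (the function name, the embedding module, the assumption that gradients exist) is an illustration, and both `torch.distributed.init_process_group` and `initialize_model_parallel` must already have run.

```python
import torch
from apex.transformer import parallel_state


def allreduce_position_embedding_grad(pos_embedding: torch.nn.Embedding) -> None:
    """Sum position-embedding grads across the ranks that hold a copy of them."""
    # Only the ranks recorded in _POSITION_EMBEDDING_GLOBAL_RANKS participate;
    # other ranks simply skip the collective.
    if parallel_state.is_rank_in_position_embedding_group():
        torch.distributed.all_reduce(
            pos_embedding.weight.grad,  # assumed non-None after backward()
            group=parallel_state.get_position_embedding_group(),
        )
```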
@@ -283,7 +369,9 @@ def is_pipeline_stage_at_split():
     stage executes encoder block for a model with both encoder and
     decoder."""
     rank = get_pipeline_model_parallel_rank()
-    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
+    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(
+        rank + 1
+    )
 
 
 def set_tensor_model_parallel_world_size(world_size):
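
For context, a pure-Python sketch of the before/after-split semantics that `is_pipeline_stage_at_split` combines; the semantics are assumed from the docstrings, and the real predicates consult the global split rank rather than taking it as a parameter.

```python
# Assumed semantics: stages below the split run the encoder, stages at or
# above it run the decoder.
def before_split(rank: int, split: int) -> bool:
    return rank < split

def after_split(rank: int, split: int) -> bool:
    return rank >= split

def at_split(rank: int, split: int) -> bool:
    # True on the last encoder stage: this stage is before the split,
    # and the next stage is past it.
    return before_split(rank, split) and after_split(rank + 1, split)

assert at_split(1, split=2)       # stage 1 is the last encoder stage
assert not at_split(0, split=2)   # stage 0 still has an encoder stage after it
```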
@@ -342,6 +430,9 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
 
 
+# TODO (mkozuki): Add [`get_num_layers`](https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/mpu/initialize.py#L321) here, maybe?
+
+
 def get_pipeline_model_parallel_split_rank():
     """Return my rank for the pipeline model parallel split rank."""
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
@@ -368,12 +459,16 @@ def is_pipeline_first_stage(ignore_virtual=False):
 def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        virtual_pipeline_model_parallel_world_size = get_virtual_pipeline_model_parallel_world_size()
+        virtual_pipeline_model_parallel_world_size = (
+            get_virtual_pipeline_model_parallel_world_size()
+        )
         if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
             virtual_pipeline_model_parallel_world_size - 1
         ):
             return False
-    return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
+    return get_pipeline_model_parallel_rank() == (
+        get_pipeline_model_parallel_world_size() - 1
+    )
 
 
 def get_virtual_pipeline_model_parallel_rank():
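
A minimal sketch (assumed semantics, not the library's code) of the virtual-pipeline condition reformatted above: with an interleaved schedule, a rank only counts as the last stage while executing its final model chunk.

```python
def is_last_stage(pp_rank, pp_size, vpp_rank=None, vpp_size=None, ignore_virtual=False):
    # Not on the final virtual chunk yet: not the "true" last stage.
    if not ignore_virtual and vpp_size is not None and vpp_rank != vpp_size - 1:
        return False
    return pp_rank == pp_size - 1

assert is_last_stage(3, 4, vpp_rank=1, vpp_size=2)      # last chunk on last rank
assert not is_last_stage(3, 4, vpp_rank=0, vpp_size=2)  # earlier chunk, same rank
```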
@@ -402,26 +497,42 @@ def get_tensor_model_parallel_src_rank():
     return (global_rank // local_world_size) * local_world_size
 
 
+def get_data_parallel_src_rank():
+    """Calculate the global rank corresponding to the first local rank in the data parallel group."""
+    global_rank = torch.distributed.get_rank()
+    data_parallel_size: int = get_data_parallel_world_size()
+    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
+    return global_rank % num_data_parallel_groups
+
+
 def get_pipeline_model_parallel_first_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]
 
 
 def get_pipeline_model_parallel_last_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]
 
 
 def get_pipeline_model_parallel_next_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
 
 
 def get_pipeline_model_parallel_prev_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
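
A worked example (standalone sketch, assumed configuration) of the rank arithmetic in the new `get_data_parallel_src_rank` and the existing next/prev wraparound:

```python
# Assumed layout: world_size = 8, tensor-parallel size 4, no pipeline
# parallelism, so data_parallel_size = 2 and the 4 data-parallel groups are
# {g, g + 4}. The src rank of the group containing `rank` is then rank % 4.
world_size, data_parallel_size = 8, 2
num_data_parallel_groups = world_size // data_parallel_size  # 4
src_ranks = [rank % num_data_parallel_groups for rank in range(world_size)]
print(src_ranks)  # [0, 1, 2, 3, 0, 1, 2, 3]: e.g. rank 6's group is {2, 6}, src 2

# Separately, next/prev pipeline ranks index into the pipeline group's global
# ranks modulo the pipeline world size, so the last stage wraps to the first.
pipeline_global_ranks = [1, 5]  # an assumed pipeline group of size 2
for rank_in_pipeline in range(2):
    nxt = pipeline_global_ranks[(rank_in_pipeline + 1) % 2]
    prv = pipeline_global_ranks[(rank_in_pipeline - 1) % 2]
    print(rank_in_pipeline, nxt, prv)  # 0 -> next 5, prev 5; 1 -> next 1, prev 1
```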
@@ -437,6 +548,9 @@ def get_data_parallel_rank():
     return torch.distributed.get_rank(group=get_data_parallel_group())
 
 
+# note (mkozuki): `destroy_model_parallel` voids more global variables than Megatron-LM.
+# Otherwise pipeline parallel forward_backward functions test hangs possibly because
+# the clean-up of the original is NOT enough.
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
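
A hedged end-to-end sketch of the lifecycle the new note is about: initialize the parallel state, use it, then tear it down so a subsequent test starts clean. The backend and launcher are assumptions; run under a multi-process launcher such as `torchrun --nproc_per_node=2`.

```python
import torch
from apex.transformer import parallel_state

# Assumes environment variables set by the launcher (RANK, WORLD_SIZE, ...).
torch.distributed.init_process_group(backend="nccl")
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=1,
    pipeline_model_parallel_size_=2,
)
assert parallel_state.model_parallel_is_initialized()
# ... exercise pipeline-parallel forward/backward here ...
parallel_state.destroy_model_parallel()  # resets every group/rank global to None
torch.distributed.destroy_process_group()
```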