Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
069ff336
Unverified
Commit
069ff336
authored
Feb 23, 2022
by
Masaki Kozuki
Committed by
GitHub
Feb 23, 2022
Browse files
access to pipeline_model_parallel_split_rank (#1300)
parent
ab1a93a7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
0 deletions
+27
-0
apex/transformer/parallel_state.py
apex/transformer/parallel_state.py
+6
-0
tests/L0/run_transformer/run_initialize_test.py
tests/L0/run_transformer/run_initialize_test.py
+21
-0
No files found.
apex/transformer/parallel_state.py
View file @
069ff336
...
@@ -342,6 +342,12 @@ def get_pipeline_model_parallel_rank():
...
@@ -342,6 +342,12 @@ def get_pipeline_model_parallel_rank():
return
torch
.
distributed
.
get_rank
(
group
=
get_pipeline_model_parallel_group
())
return
torch
.
distributed
.
get_rank
(
group
=
get_pipeline_model_parallel_group
())
def get_pipeline_model_parallel_split_rank():
    """Return my rank for the pipeline model parallel split rank.

    Returns:
        The configured pipeline-model-parallel split rank, or ``None`` when
        no split rank has been set (presumably by ``initialize_model_parallel``
        -- TODO confirm against the initializer, which is outside this view).
    """
    # A ``global`` declaration is only required for *assignment*; reading the
    # module-level variable works without it, so the redundant declaration
    # present in the original has been dropped.
    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
def
is_pipeline_first_stage
(
ignore_virtual
=
False
):
def
is_pipeline_first_stage
(
ignore_virtual
=
False
):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if
not
ignore_virtual
:
if
not
ignore_virtual
:
...
...
tests/L0/run_transformer/run_initialize_test.py
View file @
069ff336
...
@@ -80,6 +80,25 @@ def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
...
@@ -80,6 +80,25 @@ def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
# Checks
# Checks
src_rank
=
torch
.
distributed
.
get_rank
()
-
parallel_state
.
get_tensor_model_parallel_rank
()
src_rank
=
torch
.
distributed
.
get_rank
()
-
parallel_state
.
get_tensor_model_parallel_rank
()
assert
parallel_state
.
get_tensor_model_parallel_src_rank
()
==
src_rank
assert
parallel_state
.
get_tensor_model_parallel_src_rank
()
==
src_rank
split_rank
=
parallel_state
.
get_pipeline_model_parallel_split_rank
()
assert
split_rank
is
None
# Reset groups
parallel_state
.
destroy_model_parallel
()
torch
.
distributed
.
barrier
()
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'>> passed the test :-)'
)
def test_pipeline_model_parallel_split_rank():
    """Check that the split rank passed to ``initialize_model_parallel`` is
    reported back by ``parallel_state.get_pipeline_model_parallel_split_rank``.

    Requires ``torch.distributed`` to be initialized by the surrounding test
    driver; starts from a torn-down parallel state and tears it down again.
    """
    pipeline_model_parallel_split_rank_ = 1
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank_,
    )
    assert parallel_state.model_parallel_is_initialized()
    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    # Bug fix: compare integer values with ``==``, not ``is``. The original
    # ``split_rank is pipeline_model_parallel_split_rank_`` only passed
    # because CPython happens to intern small ints; identity is not a
    # guaranteed equality test for integers.
    assert split_rank == pipeline_model_parallel_split_rank_
    # Reset groups
    parallel_state.destroy_model_parallel()
...
@@ -101,4 +120,6 @@ if __name__ == '__main__':
...
@@ -101,4 +120,6 @@ if __name__ == '__main__':
test_initialize_model_parallel
(
tensor_model_parallel_size
)
test_initialize_model_parallel
(
tensor_model_parallel_size
)
print_separator
(
'test model parallel source rank'
)
print_separator
(
'test model parallel source rank'
)
test_get_tensor_model_parallel_src_rank
(
tensor_model_parallel_size
)
test_get_tensor_model_parallel_src_rank
(
tensor_model_parallel_size
)
print_separator
(
'test pipeline model parallel split rank'
)
test_pipeline_model_parallel_split_rank
()
tensor_model_parallel_size
*=
2
tensor_model_parallel_size
*=
2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment