Unverified Commit e95c3b9c authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

add setter of pipeline model parallel split rank (#1306)

parent 17e1a1f6
...@@ -348,6 +348,12 @@ def get_pipeline_model_parallel_split_rank(): ...@@ -348,6 +348,12 @@ def get_pipeline_model_parallel_split_rank():
return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
def set_pipeline_model_parallel_split_rank(pipeline_model_parallel_split_rank: int):
"""Set my rank for the pipeline model parallel split rank."""
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
def is_pipeline_first_stage(ignore_virtual=False): def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise.""" """Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual: if not ignore_virtual:
......
...@@ -100,6 +100,11 @@ def test_pipeline_model_parallel_split_rank(): ...@@ -100,6 +100,11 @@ def test_pipeline_model_parallel_split_rank():
split_rank = parallel_state.get_pipeline_model_parallel_split_rank() split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
assert split_rank is pipeline_model_parallel_split_rank_ assert split_rank is pipeline_model_parallel_split_rank_
fake_split_rank = 7
parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
assert split_rank == fake_split_rank
# Reset groups # Reset groups
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment