Unverified Commit 069ff336 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

access to pipeline_model_parallel_split_rank (#1300)

parent ab1a93a7
...@@ -342,6 +342,12 @@ def get_pipeline_model_parallel_rank(): ...@@ -342,6 +342,12 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_pipeline_model_parallel_split_rank():
    """Return the pipeline-model-parallel split rank.

    This is the module-level value recorded at initialization time; it is
    ``None`` when no split rank was configured.
    """
    # Reading a module-level name needs no `global` declaration.
    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
def is_pipeline_first_stage(ignore_virtual=False): def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise.""" """Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual: if not ignore_virtual:
......
...@@ -80,6 +80,25 @@ def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): ...@@ -80,6 +80,25 @@ def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
# Checks # Checks
src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank() src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
assert split_rank is None
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_pipeline_model_parallel_split_rank():
    """Check that the split rank passed to ``initialize_model_parallel``
    is what ``get_pipeline_model_parallel_split_rank`` later reports."""
    pipeline_model_parallel_split_rank_ = 1

    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank_)
    assert parallel_state.model_parallel_is_initialized()

    split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
    # Compare by value, not identity: `is` on ints only works by accident
    # of CPython's small-integer interning.
    assert split_rank == pipeline_model_parallel_split_rank_

    # Reset groups
    parallel_state.destroy_model_parallel()
...@@ -101,4 +120,6 @@ if __name__ == '__main__': ...@@ -101,4 +120,6 @@ if __name__ == '__main__':
test_initialize_model_parallel(tensor_model_parallel_size) test_initialize_model_parallel(tensor_model_parallel_size)
print_separator('test model parallel source rank') print_separator('test model parallel source rank')
test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
print_separator('test pipeline model parallel split rank')
test_pipeline_model_parallel_split_rank()
tensor_model_parallel_size *= 2 tensor_model_parallel_size *= 2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment