Commit 9ad1944d authored by Vijay Korthikanti

bug fix for get_data_parallel_src_rank: derive the source rank from the cached global ranks of the caller's data parallel group (_DATA_PARALLEL_GLOBAL_RANKS) instead of modular arithmetic on the global rank, and correct the docstring to say "data parallel group".

parent 739cb43d
@@ -54,6 +54,12 @@ _POSITION_EMBEDDING_GLOBAL_RANKS = None
 # rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
+# A list of global ranks for each data parallel group to ease calculation of the source
+# rank when broadcasting weights from src to all other data parallel ranks
+_DATA_PARALLEL_GLOBAL_RANKS = None
 
 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
     return _DATA_PARALLEL_GROUP is None
@@ -124,6 +130,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
+    global _DATA_PARALLEL_GLOBAL_RANKS
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group is already initialized'
     all_data_parallel_group_ranks = []
@@ -137,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
             group = torch.distributed.new_group(ranks)
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
+                _DATA_PARALLEL_GLOBAL_RANKS = ranks
 
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
@@ -478,11 +486,10 @@ def get_tensor_model_parallel_src_rank():
 
 def get_data_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
-    in the tensor model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    data_parallel_size = get_data_parallel_world_size()
-    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
-    return global_rank % num_data_parallel_groups
+    in the data parallel group."""
+    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
+        "Data parallel group is not initialized"
+    return _DATA_PARALLEL_GLOBAL_RANKS[0]
 
 def get_pipeline_model_parallel_first_rank():
...
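
Why the fix matters, in miniature: the sketch below is a torch-free illustration, not part of the commit, of how the old modulo formula can point at a rank outside the caller's data-parallel group once pipeline parallelism is enabled. It assumes Megatron-LM's usual rank layout, in which data-parallel ranks within a pipeline stage are strided by the tensor-model-parallel size; the helper names here are hypothetical.

# Illustrative sketch only: mirrors the rank layout built by
# initialize_model_parallel(), but none of these helpers exist in Megatron-LM.

def build_data_parallel_groups(tensor_size, pipeline_size, data_size):
    """Enumerate the data-parallel groups as lists of global ranks."""
    world_size = tensor_size * pipeline_size * data_size
    ranks_per_stage = world_size // pipeline_size
    groups = []
    for i in range(pipeline_size):
        start, end = i * ranks_per_stage, (i + 1) * ranks_per_stage
        for j in range(tensor_size):
            groups.append(list(range(start + j, end, tensor_size)))
    return groups

def old_src_rank(global_rank, world_size, data_size):
    # Pre-fix behavior: modular arithmetic on the caller's global rank.
    num_data_parallel_groups = world_size // data_size
    return global_rank % num_data_parallel_groups

def new_src_rank(group_ranks):
    # Post-fix behavior: first entry of the cached group ranks.
    return group_ranks[0]

if __name__ == "__main__":
    t, p, d = 2, 2, 2          # tensor, pipeline, data parallel sizes
    world = t * p * d
    for group in build_data_parallel_groups(t, p, d):
        for rank in group:
            print(f"rank {rank} in {group}: "
                  f"old src {old_src_rank(rank, world, d)}, "
                  f"new src {new_src_rank(group)}")

With p = 1 the old formula happens to give the right answer, but for the group [4, 6] above it yields 0 for rank 4 and 2 for rank 6, neither of which is in the group (and the two members disagree), whereas the cached-ranks version returns 4 for both.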
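
A hedged usage sketch of the fixed helper: a typical caller broadcasts a tensor from the first data-parallel rank to the rest of its group. The wrapper function below is made up for illustration, and the import path is assumed; get_data_parallel_group() is the existing companion accessor in the same module, and the broadcast itself is plain torch.distributed.

import torch
from megatron import mpu  # assumed import path; adjust to the local layout

def broadcast_from_data_parallel_src(tensor):
    # Send `tensor` from the first rank of this process's data-parallel group
    # to every other rank in that group.
    torch.distributed.broadcast(tensor,
                                src=mpu.get_data_parallel_src_rank(),
                                group=mpu.get_data_parallel_group())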