Commit 9ad1944d authored by Vijay Korthikanti

bug fix for get_data_parallel_src_rank

parent 739cb43d
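
For context on the fix below: the old get_data_parallel_src_rank derived the source rank as global_rank % num_data_parallel_groups, which is only correct when pipeline_model_parallel_size is 1. As an illustrative configuration (not part of the commit), take world_size = 8 with tensor_model_parallel_size = 2 and pipeline_model_parallel_size = 2: the data-parallel groups come out as [0, 2], [1, 3], [4, 6] and [5, 7]. For global rank 5 the old expression gives 5 % 4 = 1, a rank that is not even in rank 5's data-parallel group; the correct source rank is 5, the first entry of [5, 7]. The fix records each group's global ranks at initialization time and simply returns element 0.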
@@ -54,6 +54,12 @@ _POSITION_EMBEDDING_GLOBAL_RANKS = None
 # rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
+# A list of global ranks for each data parallel group to ease calculation of the source
+# rank when broadcasting weights from src to all other data parallel ranks
+_DATA_PARALLEL_GLOBAL_RANKS = None
 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
     return _DATA_PARALLEL_GROUP is None
@@ -124,6 +130,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
+    global _DATA_PARALLEL_GLOBAL_RANKS
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group is already initialized'
     all_data_parallel_group_ranks = []
@@ -137,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
             group = torch.distributed.new_group(ranks)
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
+                _DATA_PARALLEL_GLOBAL_RANKS = ranks
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
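
To see what each rank ends up storing in _DATA_PARALLEL_GLOBAL_RANKS, here is a minimal stand-alone sketch of the group layout produced by the loop touched above. The loop structure and variable names are reconstructed from initialize_model_parallel and are assumptions, not part of the diff; the 8-GPU, 2-way tensor, 2-way pipeline configuration is illustrative only.

# Sketch (plain Python, no torch.distributed needed): enumerate the list of
# data-parallel global ranks each global rank would record, mirroring the
# nested loop in initialize_model_parallel. Configuration values are assumed.
world_size = 8
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 2

num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

data_parallel_ranks_of = {}  # global rank -> that rank's data-parallel group
for i in range(pipeline_model_parallel_size):
    start_rank = i * num_pipeline_model_parallel_groups
    end_rank = (i + 1) * num_pipeline_model_parallel_groups
    for j in range(tensor_model_parallel_size):
        ranks = list(range(start_rank + j, end_rank, tensor_model_parallel_size))
        for rank in ranks:
            data_parallel_ranks_of[rank] = ranks

for rank in range(world_size):
    print(rank, data_parallel_ranks_of[rank])
# Prints: 0 [0, 2], 1 [1, 3], 2 [0, 2], 3 [1, 3],
#         4 [4, 6], 5 [5, 7], 6 [4, 6], 7 [5, 7]

With this layout the source rank for, say, global rank 5 is data_parallel_ranks_of[5][0] == 5, which is exactly what the fixed get_data_parallel_src_rank returns via _DATA_PARALLEL_GLOBAL_RANKS[0].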
@@ -478,11 +486,10 @@ def get_tensor_model_parallel_src_rank():
 def get_data_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
-    in the tensor model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    data_parallel_size = get_data_parallel_world_size()
-    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
-    return global_rank % num_data_parallel_groups
+    in the data parallel group."""
+    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
+        "Data parallel group is not initialized"
+    return _DATA_PARALLEL_GLOBAL_RANKS[0]
 def get_pipeline_model_parallel_first_rank():
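
After the fix, the helper can be used directly as the src argument of a broadcast across the caller's data-parallel group. A minimal usage sketch, assuming the conventional megatron.mpu import path and a hypothetical weights tensor that exists on every data-parallel rank:

import torch
from megatron import mpu  # import path assumed; adjust to your tree

def sync_across_data_parallel_ranks(weights: torch.Tensor) -> None:
    # First global rank in the caller's data-parallel group (the fixed helper).
    src = mpu.get_data_parallel_src_rank()
    # Process group containing exactly the caller's data-parallel ranks.
    group = mpu.get_data_parallel_group()
    # Rank `src` sends its copy of `weights`; the other ranks in `group` receive it.
    torch.distributed.broadcast(weights, src=src, group=group)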