Unverified Commit 5ff5a884 authored by Perkz Zheng, committed by GitHub

update: mpu for t5 rpe (#1416)



* update: mpu for t5 rpe

* update: add rpe mpu group test

* fix semicolon bugs
Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>

* fix semicolon bugs
Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>
parent 8a7a3325
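For orientation, here is a minimal usage sketch (not part of this commit) of the relative position embedding (RPE) groups this change adds to the mpu/parallel_state module for pipeline-parallel T5-style models. The keyword names follow the Megatron-style signature this module uses (the trailing-underscore pipeline_model_parallel_split_rank_ appears in the diff below); the world size, parallel sizes, and split point are illustrative assumptions.

import torch
from apex.transformer import parallel_state

# Launched with torchrun so torch.distributed can read rank/world size from
# the environment; 8 ranks are assumed for this sketch.
torch.distributed.init_process_group(backend="nccl")

# 2-way tensor parallel, 4 pipeline stages, decoder starting at stage 2:
# pipeline stages 0-1 hold the encoder RPE, stages 2-3 the decoder RPE.
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=2,
    pipeline_model_parallel_size_=4,
    pipeline_model_parallel_split_rank_=2,
)

# Each rank can now look up the group that shares its copy of the RPE.
if parallel_state.is_rank_in_encoder_relative_position_embedding_group():
    rpe_group = parallel_state.get_encoder_relative_position_embedding_group()
elif parallel_state.is_rank_in_decoder_relative_position_embedding_group():
    rpe_group = parallel_state.get_decoder_relative_position_embedding_group()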
@@ -42,6 +42,9 @@ _MODEL_PARALLEL_GROUP = None
 _EMBEDDING_GROUP = None
 # Position embedding group.
 _POSITION_EMBEDDING_GROUP = None
+# Relative position embedding group.
+_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
+_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
@@ -61,6 +64,10 @@ _EMBEDDING_GLOBAL_RANKS = None
 # A list of ranks that have a copy of the position embedding.
 _POSITION_EMBEDDING_GLOBAL_RANKS = None
+# A list of ranks that have a copy of the relative position embedding.
+_ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None
+_DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None
 # A list of global ranks for each pipeline group to ease calculation of the source
 # rank when broadcasting from the first or last pipeline stage
 _PIPELINE_GLOBAL_RANKS = None
@@ -228,6 +235,13 @@ def initialize_model_parallel(
     assert (
         _POSITION_EMBEDDING_GROUP is None
     ), "position embedding group is already initialized"
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is None or \
+        _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is None, \
+        'relative position embedding group is already initialized'
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
         group = torch.distributed.new_group(ranks, backend=p2p_backend)
@@ -236,10 +250,18 @@
             _PIPELINE_GLOBAL_RANKS = ranks
         # Setup embedding group (to exchange gradients between
         # first and last stages).
+        encoder_relative_position_embedding_ranks = None
+        decoder_relative_position_embedding_ranks = None
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
             position_embedding_ranks = [ranks[0]]
+            encoder_relative_position_embedding_ranks = [ranks[0]]
+            decoder_relative_position_embedding_ranks = [ranks[0]]
             if pipeline_model_parallel_split_rank_ is not None:
+                encoder_relative_position_embedding_ranks = \
+                    ranks[:pipeline_model_parallel_split_rank_]
+                decoder_relative_position_embedding_ranks = \
+                    ranks[pipeline_model_parallel_split_rank_:]
                 if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
                     embedding_ranks = [
                         ranks[0],
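To make the split concrete, the slicing above simply partitions one pipeline group's ranks at the split point. A standalone sketch (the group [0, 1, 2, 3] and split rank 2 are assumptions for illustration, not values from the commit):

# One pipeline-model-parallel group and an assumed encoder/decoder split point.
ranks = range(0, 4)
pipeline_model_parallel_split_rank_ = 2

# Mirrors the assignment in the diff above: encoder stages before the split,
# decoder stages from the split onward.
encoder_relative_position_embedding_ranks = ranks[:pipeline_model_parallel_split_rank_]
decoder_relative_position_embedding_ranks = ranks[pipeline_model_parallel_split_rank_:]

assert list(encoder_relative_position_embedding_ranks) == [0, 1]
assert list(decoder_relative_position_embedding_ranks) == [2, 3]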
@@ -257,6 +279,8 @@
         else:
             embedding_ranks = ranks
             position_embedding_ranks = ranks
+            encoder_relative_position_embedding_ranks = ranks
+            decoder_relative_position_embedding_ranks = ranks
 
         group = torch.distributed.new_group(embedding_ranks, backend=default_backend)
         if rank in embedding_ranks:
@@ -270,6 +294,19 @@
         if rank in ranks:
             _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
 
+        group = torch.distributed.new_group(encoder_relative_position_embedding_ranks)
+        if rank in encoder_relative_position_embedding_ranks:
+            _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
+                encoder_relative_position_embedding_ranks
+
+        group = torch.distributed.new_group(decoder_relative_position_embedding_ranks)
+        if rank in decoder_relative_position_embedding_ranks:
+            _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
+                decoder_relative_position_embedding_ranks
 
 def get_rank_info() -> Tuple[int, int, int]:
     """Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger."""
@@ -335,6 +372,17 @@ def get_position_embedding_group():
     ), "position embedding group is not initialized"
     return _POSITION_EMBEDDING_GROUP
+
+def get_encoder_relative_position_embedding_group():
+    """Get the encoder relative position embedding group the caller rank belongs to."""
+    assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \
+        'encoder relative position embedding group is not initialized'
+    return _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
+
+def get_decoder_relative_position_embedding_group():
+    """Get the decoder relative position embedding group the caller rank belongs to."""
+    assert _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \
+        'decoder relative position embedding group is not initialized'
+    return _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
 
 def is_rank_in_embedding_group(ignore_virtual=False):
     """Return true if current rank is in embedding group, False otherwise."""
@@ -358,6 +406,17 @@ def is_rank_in_position_embedding_group():
     global _POSITION_EMBEDDING_GLOBAL_RANKS
     return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
+
+def is_rank_in_encoder_relative_position_embedding_group():
+    """Return true if current rank is in encoder relative position embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+
+def is_rank_in_decoder_relative_position_embedding_group():
+    """Return true if current rank is in decoder relative position embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
 
 def is_pipeline_stage_before_split(rank=None):
     """Return True if pipeline stage executes encoder block for a model
@@ -590,6 +649,10 @@ def destroy_model_parallel():
     _EMBEDDING_GROUP = None
     global _POSITION_EMBEDDING_GROUP
     _POSITION_EMBEDDING_GROUP = None
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
     _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
@@ -120,6 +120,16 @@ class ParallelStateTestBase:
             fake_split_rank, parallel_state.get_pipeline_model_parallel_split_rank()
         )
+        # relative position embedding groups check
+        self.assertEqual(
+            expected_pipeline_rank < pipeline_model_parallel_split_rank,
+            parallel_state.is_rank_in_encoder_relative_position_embedding_group(),
+        )
+        self.assertEqual(
+            expected_pipeline_rank >= pipeline_model_parallel_split_rank,
+            parallel_state.is_rank_in_decoder_relative_position_embedding_group(),
+        )
         parallel_state.destroy_model_parallel()
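Read concretely, the new assertions say that with a pipeline split every stage belongs to exactly one of the two RPE groups. A tiny sketch of the expected membership for an assumed 4-stage pipeline split at stage 2 (values are illustrative, not taken from the test):

pipeline_model_parallel_split_rank = 2
for expected_pipeline_rank in range(4):
    in_encoder = expected_pipeline_rank < pipeline_model_parallel_split_rank
    in_decoder = expected_pipeline_rank >= pipeline_model_parallel_split_rank
    # Stages 0-1 are encoder-side, stages 2-3 decoder-side; never both, never neither.
    assert in_encoder != in_decoder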