Commit 3af6725d authored by Lawrence McAfee

working for t5 [ encoder embedding only ]

parent 1fa6990c
@@ -329,16 +329,35 @@ def get_num_layers(args, is_encoder_and_decoder_model):
     if get_pipeline_model_parallel_world_size() > 1:
         if is_encoder_and_decoder_model:
             # >>>
-            raise Exception("fix for t5.")
+            # raise Exception("fix for t5.")
             # <<<
             assert args.pipeline_model_parallel_split_rank is not None
-            num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
-            num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
+            # >>>
+            # num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
+            # +++
+            num_ranks_in_encoder = (
+                args.pipeline_model_parallel_split_rank - 1
+                if args.standalone_embed_stage else
+                args.pipeline_model_parallel_split_rank
+            )
+            # <<<
+            # >>>
+            # num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
+            # +++
+            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
+            # <<<
+            # >>>
+            # raise Exception(">>>> standalone %d, encoder %d, decoder %d. <<<<" % (
+            #     args.standalone_embed_stage,
+            #     num_ranks_in_encoder,
+            #     num_ranks_in_decoder,
+            # ))
+            # <<<
             assert args.num_layers % num_ranks_in_encoder == 0, \
-                'num_layers must be divisible by number of ranks given to encoder'
+                'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
             assert args.num_layers % num_ranks_in_decoder == 0, \
-                'num_layers must be divisible by number of ranks given to decoder'
+                'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
-            if is_pipeline_stage_before_split():
+            if is_pipeline_stage_before_split(): # args):
                 num_layers = args.num_layers // num_ranks_in_encoder
             else:
                 num_layers = args.num_layers // num_ranks_in_decoder
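Note: the new rank-split arithmetic above can be sanity-checked in isolation. Below is a minimal plain-Python sketch; the helper name split_ranks and the example numbers are illustrative, not part of the commit.

def split_ranks(split_rank, transformer_pp_size, standalone_embed_stage):
    # Mirrors the updated get_num_layers() logic: with a standalone
    # embedding stage, the first pipeline rank holds only the embedding,
    # so it is excluded from the encoder's transformer ranks.
    num_ranks_in_encoder = (
        split_rank - 1 if standalone_embed_stage else split_rank)
    # transformer_pp_size stands in for
    # args.transformer_pipeline_model_parallel_size (the pipeline world
    # size minus any embedding-only stage).
    num_ranks_in_decoder = transformer_pp_size - num_ranks_in_encoder
    return num_ranks_in_encoder, num_ranks_in_decoder

# Pipeline world size 4, split at rank 2:
print(split_ranks(2, 4, False))  # (2, 2) -- old behavior
print(split_ranks(2, 3, True))   # (1, 2) -- embedding-only stage excluded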
@@ -419,6 +438,9 @@ def is_rank_in_position_embedding_group():
     return rank in _POSITION_EMBEDDING_GLOBAL_RANKS


+# >>>
+# def is_pipeline_stage_before_split(args, rank=None):
+# <<<
 def is_pipeline_stage_before_split(rank=None):
     """Return True if pipeline stage executes encoder block for a model
     with both encoder and decoder."""
@@ -426,6 +448,11 @@ def is_pipeline_stage_before_split(rank=None):
         return True
     if rank is None:
         rank = get_pipeline_model_parallel_rank()
+    # >>>
+    # if args.standalone_embed_stage:
+    #     rank += 1
+    assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
+    # <<<
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
         return True
@@ -434,6 +461,9 @@ def is_pipeline_stage_before_split(rank=None):
     return False


+# >>>
+# def is_pipeline_stage_after_split(args, rank=None):
+# <<<
 def is_pipeline_stage_after_split(rank=None):
     """Return True if pipeline stage executes decoder block for a model
     with both encoder and decoder."""
@@ -441,6 +471,11 @@ def is_pipeline_stage_after_split(rank=None):
         return True
     if rank is None:
         rank = get_pipeline_model_parallel_rank()
+    # >>>
+    # if args.standalone_embed_stage:
+    #     rank += 1
+    assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
+    # <<<
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
         return True
...
@@ -254,8 +254,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
             pre_process = rank == 0 or rank == split_rank
             post_process = (rank == (split_rank - 1)) or (
                 rank == (world_size - 1))
-            add_encoder = mpu.is_pipeline_stage_before_split()
-            add_decoder = mpu.is_pipeline_stage_after_split()
+            # >>>
+            add_encoder = mpu.is_pipeline_stage_before_split() # args)
+            add_decoder = mpu.is_pipeline_stage_after_split() # args)
+            # <<<
         model = model_provider_func(
             pre_process=pre_process,
             post_process=post_process,
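Note: a sketch of how these flags land on each pipeline rank, using the same pre_process/post_process/add_encoder/add_decoder logic as the get_model() hunk above; the helper stage_flags and the example sizes are hypothetical.

def stage_flags(world_size, split_rank):
    # Which modules each pipeline rank builds for an encoder-decoder model.
    for rank in range(world_size):
        print(dict(
            rank=rank,
            pre_process=(rank == 0 or rank == split_rank),
            post_process=(rank == split_rank - 1 or rank == world_size - 1),
            add_encoder=rank < split_rank,   # is_pipeline_stage_before_split
            add_decoder=rank >= split_rank,  # is_pipeline_stage_after_split
        ))

stage_flags(world_size=4, split_rank=2)
# rank 0: pre_process + encoder (holds the encoder embedding)
# rank 1: post_process + encoder (final encoder stage)
# rank 2: pre_process + decoder (first decoder stage)
# rank 3: post_process + decoder (final decoder stage / output head)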
...