"vscode:/vscode.git/clone" did not exist on "72c77763559317b2c8bddfd67e173b67aa1facb0"
Commit 3af6725d authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

working for t5 [ encoder embedding only ]

parent 1fa6990c
@@ -329,16 +329,35 @@ def get_num_layers(args, is_encoder_and_decoder_model):
     if get_pipeline_model_parallel_world_size() > 1:
         if is_encoder_and_decoder_model:
             # >>>
-            raise Exception("fix for t5.")
+            # raise Exception("fix for t5.")
             # <<<
             assert args.pipeline_model_parallel_split_rank is not None
-            num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
-            num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
+            # >>>
+            # num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
+            # +++
+            num_ranks_in_encoder = (
+                args.pipeline_model_parallel_split_rank - 1
+                if args.standalone_embed_stage else
+                args.pipeline_model_parallel_split_rank
+            )
+            # <<<
+            # >>>
+            # num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
+            # +++
+            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
+            # <<<
+            # >>>
+            # raise Exception(">>>> standalone %d, encoder %d, decoder %d. <<<<" % (
+            #     args.standalone_embed_stage,
+            #     num_ranks_in_encoder,
+            #     num_ranks_in_decoder,
+            # ))
+            # <<<
             assert args.num_layers % num_ranks_in_encoder == 0, \
-                'num_layers must be divisible by number of ranks given to encoder'
+                'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
             assert args.num_layers % num_ranks_in_decoder == 0, \
-                'num_layers must be divisible by number of ranks given to decoder'
-            if is_pipeline_stage_before_split():
+                'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
+            if is_pipeline_stage_before_split(): # args):
                 num_layers = args.num_layers // num_ranks_in_encoder
             else:
                 num_layers = args.num_layers // num_ranks_in_decoder
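A minimal standalone sketch of the rank-count arithmetic this hunk introduces (not Megatron code; the helper name split_ranks is hypothetical). It assumes, per the substitution above, that transformer_pipeline_model_parallel_size is the pipeline world size minus one when the standalone embedding stage is enabled, and that the embedding-only stage occupies pipeline rank 0:

# Hypothetical sketch of the rank-split arithmetic; assumes the standalone
# embedding stage occupies pipeline rank 0 and is excluded from the
# transformer stage count (transformer size = world size - 1).
def split_ranks(world_size, split_rank, standalone_embed_stage):
    """Return (num_ranks_in_encoder, num_ranks_in_decoder)."""
    transformer_size = world_size - 1 if standalone_embed_stage else world_size
    num_in_encoder = split_rank - 1 if standalone_embed_stage else split_rank
    num_in_decoder = transformer_size - num_in_encoder
    return num_in_encoder, num_in_decoder

# Without the standalone stage: 8 ranks split at 4 -> 4 encoder, 4 decoder.
assert split_ranks(8, 4, False) == (4, 4)
# With it: rank 0 holds only the embedding, so ranks 1-3 run the encoder
# and ranks 4-7 run the decoder.
assert split_ranks(8, 4, True) == (3, 4)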
@@ -419,6 +438,9 @@ def is_rank_in_position_embedding_group():
     return rank in _POSITION_EMBEDDING_GLOBAL_RANKS


+# >>>
+# def is_pipeline_stage_before_split(args, rank=None):
+# <<<
 def is_pipeline_stage_before_split(rank=None):
     """Return True if pipeline stage executes encoder block for a model
     with both encoder and decoder."""
@@ -426,6 +448,11 @@ def is_pipeline_stage_before_split(rank=None):
         return True
     if rank is None:
         rank = get_pipeline_model_parallel_rank()
+    # >>>
+    # if args.standalone_embed_stage:
+    #     rank += 1
+    assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
+    # <<<
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
         return True
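For reference, a self-contained sketch of what the before-split predicate computes, with the split rank and world size passed explicitly instead of read from _PIPELINE_MODEL_PARALLEL_SPLIT_RANK (the function name is illustrative):

# Illustrative stand-in for is_pipeline_stage_before_split; the real function
# reads the split rank and world size from module-level parallel state.
def is_before_split(rank, split_rank, world_size):
    if world_size == 1 or split_rank is None:
        return True  # no split configured: everything counts as "before"
    return rank < split_rank

assert is_before_split(2, 4, 8)        # encoder-side rank
assert not is_before_split(5, 4, 8)    # decoder-side rank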
@@ -434,6 +461,9 @@ def is_pipeline_stage_before_split(rank=None):
     return False


+# >>>
+# def is_pipeline_stage_after_split(args, rank=None):
+# <<<
 def is_pipeline_stage_after_split(rank=None):
     """Return True if pipeline stage executes decoder block for a model
     with both encoder and decoder."""
@@ -441,6 +471,11 @@ def is_pipeline_stage_after_split(rank=None):
         return True
     if rank is None:
         rank = get_pipeline_model_parallel_rank()
+    # >>>
+    # if args.standalone_embed_stage:
+    #     rank += 1
+    assert isinstance(rank, (type(None), int)), "rank == <%s>." % type(rank).__name__
+    # <<<
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
         return True
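The decoder-side predicate mirrors it; under the same assumptions the two partition the transformer ranks whenever a split is configured:

# Illustrative stand-in for is_pipeline_stage_after_split.
def is_after_split(rank, split_rank, world_size):
    if world_size == 1 or split_rank is None:
        return True
    return rank >= split_rank

# With world_size=8 and split_rank=4, ranks 0-3 are encoder stages and
# ranks 4-7 are decoder stages.
assert [is_after_split(r, 4, 8) for r in range(8)] == [False] * 4 + [True] * 4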
@@ -254,8 +254,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
                 pre_process = rank == 0 or rank == split_rank
                 post_process = (rank == (split_rank - 1)) or (
                     rank == (world_size - 1))
-                add_encoder = mpu.is_pipeline_stage_before_split()
-                add_decoder = mpu.is_pipeline_stage_after_split()
+                # >>>
+                add_encoder = mpu.is_pipeline_stage_before_split() # args)
+                add_decoder = mpu.is_pipeline_stage_after_split() # args)
+                # <<<
             model = model_provider_func(
                 pre_process=pre_process,
                 post_process=post_process,
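To see how these flags land on each pipeline rank, here is a hypothetical walk-through (names mirror the diff; this is not the training code itself), assuming 8 pipeline ranks split at rank 4 and inlining the two split predicates:

# Hypothetical per-rank walk-through of the flags computed in get_model.
def stage_flags(rank, split_rank, world_size):
    pre_process = rank == 0 or rank == split_rank
    post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
    add_encoder = rank < split_rank    # is_pipeline_stage_before_split
    add_decoder = rank >= split_rank   # is_pipeline_stage_after_split
    return pre_process, post_process, add_encoder, add_decoder

# Rank 0 embeds encoder input; rank 3 finishes the encoder; rank 4 embeds
# decoder input; rank 7 produces the final output.
assert stage_flags(0, 4, 8) == (True, False, True, False)
assert stage_flags(3, 4, 8) == (False, True, True, False)
assert stage_flags(4, 4, 8) == (True, False, False, True)
assert stage_flags(7, 4, 8) == (False, True, False, True)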