Commit 33dc8e9c authored by Lawrence McAfee

working when no interleaving

parent 9a8b89ac
......@@ -684,6 +684,10 @@ def _add_distributed_args(parser):
group.add_argument('--deallocate-pipeline-outputs', action='store_true',
default=False, help='If set, pipeline output tensors '
'are deallocated during the forward pass.')
group.add_argument('--standalone-embed-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
'transformer layers.')
return parser
......
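The hunk above only registers the new flag in _add_distributed_args. As a quick sanity check, below is a minimal, self-contained sketch of how --standalone-embed-stage parses; the parser, group title, and test invocations are illustrative scaffolding, not the real Megatron-LM argument setup.

# Sketch only: the flag definition is copied from the hunk above; the rest is
# illustrative scaffolding for demonstration.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='distributed')
group.add_argument('--standalone-embed-stage', action='store_true',
                   default=False, help='If set, *input* embedding layer '
                   'is placed on its own pipeline stage, without any '
                   'transformer layers.')

# argparse maps the dashed flag name to an underscored attribute.
args = parser.parse_args(['--standalone-embed-stage'])
assert args.standalone_embed_stage is True
args = parser.parse_args([])
assert args.standalone_embed_stage is False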
......@@ -269,6 +269,9 @@ def set_tensor_model_parallel_world_size(world_size):
def set_pipeline_model_parallel_world_size(world_size):
# >>> debug probe: fail loudly if this setter is ever called
raise Exception("hi.")
# <<<
"""Set the pipeline model parallel size"""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
......@@ -287,6 +290,9 @@ def get_pipeline_model_parallel_world_size():
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
# >>> disabled debug probe
# raise Exception("hi.")
# <<<
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
......@@ -322,6 +328,9 @@ def get_num_layers(args, is_encoder_and_decoder_model):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
# >>> TODO: this branch still needs fixing for T5 (encoder-and-decoder) models
raise Exception("fix for t5.")
# <<<
assert args.pipeline_model_parallel_split_rank is not None
num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
......@@ -334,9 +343,27 @@ def get_num_layers(args, is_encoder_and_decoder_model):
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
'num_layers must be divisible by pipeline_model_parallel_size'
num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
transformer_pipeline_size = (
get_pipeline_model_parallel_world_size() - 1
if args.standalone_embed_stage else
get_pipeline_model_parallel_world_size()
)
assert args.num_layers % transformer_pipeline_size == 0, \
'num_layers must be divisible by transformer_pipeline_size'
num_layers = (
0
if args.standalone_embed_stage
and get_pipeline_model_parallel_rank() == 0 else
args.num_layers // transformer_pipeline_size
)
# >>> disabled debug printout (lutil.pax)
# from lutil import pax
# pax({
# "rank" : torch.distributed.get_rank(),
# "pipeline rank" : get_pipeline_model_parallel_rank(),
# "num_layers" : num_layers,
# })
# <<<
else:
num_layers = args.num_layers
return num_layers
......
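In the hunk above, the original uniform split (the first assert and num_layers = args.num_layers // get_pipeline_model_parallel_world_size()) is removed and replaced by the transformer_pipeline_size logic: when --standalone-embed-stage is set, the embedding-only first stage is excluded from the divisor and pipeline rank 0 keeps zero transformer layers. A small self-contained sketch of that arithmetic, with illustrative numbers (24 layers, pipeline size 4) that are not taken from the commit:

# Standalone sketch of the layer-count arithmetic; function and argument names
# here are made up for illustration, only the logic mirrors the hunk above.
def layers_on_rank(num_layers, pipeline_size, pipeline_rank, standalone_embed_stage):
    # Stages that hold transformer layers: all of them, unless the first
    # stage is reserved for the input embedding.
    transformer_pipeline_size = (
        pipeline_size - 1 if standalone_embed_stage else pipeline_size)
    assert num_layers % transformer_pipeline_size == 0, \
        'num_layers must be divisible by transformer_pipeline_size'
    if standalone_embed_stage and pipeline_rank == 0:
        return 0  # embedding-only stage: no transformer layers
    return num_layers // transformer_pipeline_size

# Without the flag, 24 layers over 4 stages -> 6 layers on every stage.
print([layers_on_rank(24, 4, r, False) for r in range(4)])  # [6, 6, 6, 6]
# With the flag, stage 0 holds only the embedding; the other 3 stages get 24 // 3 = 8.
print([layers_on_rank(24, 4, r, True) for r in range(4)])   # [0, 8, 8, 8]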