Commit cbf8250b authored by Vijay Korthikanti

Add support for different encoder/decoder num-layers

parent 41276b6c
@@ -225,6 +225,13 @@ def validate_args(args, defaults={}):
'can only specify one of lr-warmup-fraction ' \
'and lr-warmup-samples'
+    if args.num_layers is not None:
+        assert args.encoder_num_layers is None
+        args.encoder_num_layers = args.num_layers
+    else:
+        assert args.encoder_num_layers is not None
+        args.num_layers = args.encoder_num_layers
# Check required arguments.
required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
'max_position_embeddings']
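
The fallback added above makes --num-layers and --encoder-num-layers mutually exclusive: whichever one is given fills in the other. A minimal standalone sketch of that rule (the resolve_layer_args helper is hypothetical, not part of this patch):

import argparse

def resolve_layer_args(args):
    # Mirrors the validation above: exactly one of num_layers /
    # encoder_num_layers may be set; the other is derived from it.
    if args.num_layers is not None:
        assert args.encoder_num_layers is None
        args.encoder_num_layers = args.num_layers
    else:
        assert args.encoder_num_layers is not None
        args.num_layers = args.encoder_num_layers
    return args

# Example: only encoder_num_layers is given; num_layers is backfilled.
ns = argparse.Namespace(num_layers=None, encoder_num_layers=24)
print(resolve_layer_args(ns).num_layers)  # 24
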
@@ -352,6 +359,10 @@ def _add_network_size_args(parser):
group.add_argument('--num-layers', type=int, default=None,
help='Number of transformer layers.')
+    group.add_argument('--encoder-num-layers', type=int, default=None,
+                       help='Number of encoder transformer layers.')
+    group.add_argument('--decoder-num-layers', type=int, default=None,
+                       help='Number of decoder transformer layers.')
group.add_argument('--hidden-size', type=int, default=None,
help='Transformer hidden size.')
group.add_argument('--ffn-hidden-size', type=int, default=None,
......
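
The two new flags slot into the existing network-size argument group. An illustrative parse of an encoder-decoder configuration with a deeper encoder than decoder (the flag names come from this patch; the layer counts are arbitrary example values):

import argparse

parser = argparse.ArgumentParser()
# The two flags introduced by this patch; defaults mirror the diff above.
parser.add_argument('--encoder-num-layers', type=int, default=None,
                    help='Number of encoder transformer layers.')
parser.add_argument('--decoder-num-layers', type=int, default=None,
                    help='Number of decoder transformer layers.')

# e.g. an encoder-decoder model with unequal depths.
args = parser.parse_args(['--encoder-num-layers', '24',
                          '--decoder-num-layers', '12'])
print(args.encoder_num_layers, args.decoder_num_layers)  # 24 12
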
@@ -746,7 +746,9 @@ class ParallelTransformer(MegatronModule):
# Number of layers.
self.num_layers = mpu.get_num_layers(
-            args, args.model_type == ModelType.encoder_and_decoder)
+            args,
+            args.model_type == ModelType.encoder_and_decoder,
+            layer_type == LayerType.decoder)
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
......
@@ -313,7 +313,7 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
-def get_num_layers(args, is_encoder_and_decoder_model):
+def get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
@@ -329,20 +329,21 @@ def get_num_layers(args, is_encoder_and_decoder_model):
args.pipeline_model_parallel_split_rank
)
num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
-            assert args.num_layers % num_ranks_in_encoder == 0, \
-                'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
-            assert args.num_layers % num_ranks_in_decoder == 0, \
-                'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
+            assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
+                'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
+            assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
+                'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
if is_pipeline_stage_before_split():
num_layers = (
0
if args.standalone_embedding_stage
and get_pipeline_model_parallel_rank() == 0 else
-                    args.num_layers // num_ranks_in_encoder
+                    args.encoder_num_layers // num_ranks_in_encoder
)
else:
-                num_layers = args.num_layers // num_ranks_in_decoder
+                num_layers = args.decoder_num_layers // num_ranks_in_decoder
else:
+            assert args.num_layers == args.encoder_num_layers
assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'num_layers must be divisible by transformer_pipeline_model_parallel_size'
@@ -357,7 +358,10 @@ def get_num_layers(args, is_encoder_and_decoder_model):
args.num_layers // args.transformer_pipeline_model_parallel_size
)
else:
-        num_layers = args.num_layers
+        if not is_decoder:
+            num_layers = args.encoder_num_layers
+        else:
+            num_layers = args.decoder_num_layers
return num_layers
......
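
To make the new split concrete, here is a self-contained sketch of the per-rank arithmetic for the pipelined encoder-decoder branch above (variable names and example values are illustrative; the standalone-embedding-stage adjustment is omitted):

# Illustrative only: per-rank layer counts for a pipelined
# encoder-decoder model, ignoring the standalone-embedding-stage case.
encoder_num_layers = 24       # --encoder-num-layers
decoder_num_layers = 12       # --decoder-num-layers
pipeline_world_size = 6       # total pipeline stages
pipeline_split_rank = 4       # rank of the first decoder stage

num_ranks_in_encoder = pipeline_split_rank
num_ranks_in_decoder = pipeline_world_size - num_ranks_in_encoder

# Same divisibility requirements as the asserts above.
assert encoder_num_layers % num_ranks_in_encoder == 0
assert decoder_num_layers % num_ranks_in_decoder == 0

for rank in range(pipeline_world_size):
    if rank < pipeline_split_rank:   # stage before the split: encoder layers
        layers = encoder_num_layers // num_ranks_in_encoder   # 24 // 4 = 6
    else:                            # stage at/after the split: decoder layers
        layers = decoder_num_layers // num_ranks_in_decoder   # 12 // 2 = 6
    print(f'rank {rank}: {layers} layers')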