Commit cbf8250b authored by Vijay Korthikanti

different encoder/decoder num-layers support

parent 41276b6c
@@ -225,6 +225,13 @@ def validate_args(args, defaults={}):
         'can only specify one of lr-warmup-fraction ' \
         'and lr-warmup-samples'
+    if args.num_layers is not None:
+        assert args.encoder_num_layers is None
+        args.encoder_num_layers = args.num_layers
+    else:
+        assert args.encoder_num_layers is not None
+        args.num_layers = args.encoder_num_layers
     # Check required arguments.
     required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                      'max_position_embeddings']
@@ -352,6 +359,10 @@ def _add_network_size_args(parser):
     group.add_argument('--num-layers', type=int, default=None,
                        help='Number of transformer layers.')
+    group.add_argument('--encoder-num-layers', type=int, default=None,
+                       help='Number of encoder transformer layers.')
+    group.add_argument('--decoder-num-layers', type=int, default=None,
+                       help='Number of decoder transformer layers.')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Tansformer hidden size.')
     group.add_argument('--ffn-hidden-size', type=int, default=None,
...
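Taken together, the two hunks above add --encoder-num-layers and --decoder-num-layers as optional flags and make validate_args keep --num-layers and --encoder-num-layers consistent: whichever one is supplied fills in the other, and supplying both trips the assertion. A minimal standalone sketch of that fallback, using a bare argparse parser rather than Megatron's full argument machinery (the resolve_encoder_layers helper and the assertion messages are illustrative additions, not part of the patch):

import argparse

def resolve_encoder_layers(args):
    # Mirrors the fallback added to validate_args: exactly one of
    # num_layers / encoder_num_layers is expected, and the missing
    # one is copied from the other.
    if args.num_layers is not None:
        assert args.encoder_num_layers is None, \
            'specify only one of --num-layers and --encoder-num-layers'
        args.encoder_num_layers = args.num_layers
    else:
        assert args.encoder_num_layers is not None, \
            'specify --num-layers or --encoder-num-layers'
        args.num_layers = args.encoder_num_layers
    return args

parser = argparse.ArgumentParser()
parser.add_argument('--num-layers', type=int, default=None)
parser.add_argument('--encoder-num-layers', type=int, default=None)
parser.add_argument('--decoder-num-layers', type=int, default=None)

# Symmetric model: --num-layers populates encoder_num_layers as well.
args = resolve_encoder_layers(parser.parse_args(['--num-layers', '24']))
print(args.num_layers, args.encoder_num_layers)                           # 24 24

# Asymmetric encoder/decoder stacks: give the per-stack counts directly.
args = resolve_encoder_layers(parser.parse_args(
    ['--encoder-num-layers', '12', '--decoder-num-layers', '6']))
print(args.num_layers, args.encoder_num_layers, args.decoder_num_layers)  # 12 12 6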
@@ -746,7 +746,9 @@ class ParallelTransformer(MegatronModule):
         # Number of layers.
         self.num_layers = mpu.get_num_layers(
-            args, args.model_type == ModelType.encoder_and_decoder)
+            args,
+            args.model_type == ModelType.encoder_and_decoder,
+            layer_type == LayerType.decoder)
         self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
...
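The hunk above threads the stack's identity into the layer count: a ParallelTransformer built with layer_type == LayerType.decoder now passes is_decoder=True to mpu.get_num_layers, while encoder stacks keep the old behaviour. A rough sketch of that call shape, with stand-in enums and only the non-pipelined branch of the updated helper (everything below is reduced for illustration, not Megatron's actual classes):

from enum import Enum
from types import SimpleNamespace

class ModelType(Enum):      # stand-in for Megatron's ModelType enum
    encoder_or_decoder = 1
    encoder_and_decoder = 2

class LayerType(Enum):      # stand-in for Megatron's LayerType enum
    encoder = 1
    decoder = 2

def get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
    # Only the no-pipeline branch: pick the per-stack layer count.
    return args.decoder_num_layers if is_decoder else args.encoder_num_layers

args = SimpleNamespace(model_type=ModelType.encoder_and_decoder,
                       encoder_num_layers=12, decoder_num_layers=6)

# Same call shape as in ParallelTransformer.__init__ above.
for layer_type in (LayerType.encoder, LayerType.decoder):
    n = get_num_layers(args,
                       args.model_type == ModelType.encoder_and_decoder,
                       layer_type == LayerType.decoder)
    print(layer_type.name, n)   # encoder 12, then decoder 6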
@@ -313,7 +313,7 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
-def get_num_layers(args, is_encoder_and_decoder_model):
+def get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
     """Compute the number of transformer layers resident on the current rank."""
     if get_pipeline_model_parallel_world_size() > 1:
         if is_encoder_and_decoder_model:
@@ -329,20 +329,21 @@ def get_num_layers(args, is_encoder_and_decoder_model):
                 args.pipeline_model_parallel_split_rank
             )
             num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
-            assert args.num_layers % num_ranks_in_encoder == 0, \
-                'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
-            assert args.num_layers % num_ranks_in_decoder == 0, \
-                'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
+            assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
+                'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
+            assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
+                'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
             if is_pipeline_stage_before_split():
                 num_layers = (
                     0
                     if args.standalone_embedding_stage
                     and get_pipeline_model_parallel_rank() == 0 else
-                    args.num_layers // num_ranks_in_encoder
+                    args.encoder_num_layers // num_ranks_in_encoder
                 )
             else:
-                num_layers = args.num_layers // num_ranks_in_decoder
+                num_layers = args.decoder_num_layers // num_ranks_in_decoder
         else:
+            assert args.num_layers == args.encoder_num_layers
             assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                 'num_layers must be divisible by transformer_pipeline_model_parallel_size'
@@ -357,7 +358,10 @@ def get_num_layers(args, is_encoder_and_decoder_model):
                 args.num_layers // args.transformer_pipeline_model_parallel_size
             )
     else:
-        num_layers = args.num_layers
+        if not is_decoder:
+            num_layers = args.encoder_num_layers
+        else:
+            num_layers = args.decoder_num_layers
     return num_layers
...
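With pipeline parallelism and an encoder-and-decoder model, the updated get_num_layers splits args.encoder_num_layers over the ranks before the pipeline split and args.decoder_num_layers over the ranks after it; without pipeline parallelism the new is_decoder flag simply selects the per-stack count, as in the sketch above. The arithmetic for the pipelined case, written as a self-contained function with explicit arguments in place of the global args and parallel state (and ignoring the standalone-embedding-stage and interleaved-schedule details):

def layers_on_rank(rank, pipeline_size, split_rank,
                   encoder_num_layers, decoder_num_layers):
    # Ranks [0, split_rank) hold the encoder, ranks [split_rank, pipeline_size)
    # hold the decoder; each group divides its own layer count evenly.
    num_ranks_in_encoder = split_rank
    num_ranks_in_decoder = pipeline_size - split_rank
    assert encoder_num_layers % num_ranks_in_encoder == 0
    assert decoder_num_layers % num_ranks_in_decoder == 0
    if rank < split_rank:
        return encoder_num_layers // num_ranks_in_encoder
    return decoder_num_layers // num_ranks_in_decoder

# 6-stage pipeline, split after the fourth stage: 12 encoder layers over
# 4 encoder ranks and 6 decoder layers over 2 decoder ranks.
for r in range(6):
    print(r, layers_on_rank(r, pipeline_size=6, split_rank=4,
                            encoder_num_layers=12, decoder_num_layers=6))
# ranks 0-3 get 3 encoder layers each; ranks 4-5 get 3 decoder layers each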