Commit 69f3249f authored by Vijay Korthikanti
Browse files

address review comments

parent cbf8250b
...@@ -226,10 +226,12 @@ def validate_args(args, defaults={}): ...@@ -226,10 +226,12 @@ def validate_args(args, defaults={}):
'and lr-warmup-samples' 'and lr-warmup-samples'
if args.num_layers is not None: if args.num_layers is not None:
assert args.encoder_num_layers is None assert args.encoder_num_layers is None, \
'cannot have both num-layers and encoder-num-layers specified'
args.encoder_num_layers = args.num_layers args.encoder_num_layers = args.num_layers
else: else:
assert args.encoder_num_layers is not None assert args.encoder_num_layers is not None, \
'either num-layers or encoder-num-layers should be specified'
args.num_layers = args.encoder_num_layers args.num_layers = args.encoder_num_layers
# Check required arguments. # Check required arguments.
......
...@@ -330,9 +330,9 @@ def get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False): ...@@ -330,9 +330,9 @@ def get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
) )
num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
assert args.encoder_num_layers % num_ranks_in_encoder == 0, \ assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder) 'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
assert args.decoder_num_layers % num_ranks_in_decoder == 0, \ assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder) 'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
if is_pipeline_stage_before_split(): if is_pipeline_stage_before_split():
num_layers = ( num_layers = (
0 0
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment