Commit 2fadaa50 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

fixed t5 'get_num_layers()'

parent c04c4977
...@@ -339,7 +339,12 @@ def get_num_layers(args, is_encoder_and_decoder_model): ...@@ -339,7 +339,12 @@ def get_num_layers(args, is_encoder_and_decoder_model):
assert args.num_layers % num_ranks_in_decoder == 0, \ assert args.num_layers % num_ranks_in_decoder == 0, \
'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder) 'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
if is_pipeline_stage_before_split(): if is_pipeline_stage_before_split():
num_layers = args.num_layers // num_ranks_in_encoder num_layers = (
0
if args.standalone_embedding_stage
and get_pipeline_model_parallel_rank() == 0 else
args.num_layers // num_ranks_in_encoder
)
else: else:
num_layers = args.num_layers // num_ranks_in_decoder num_layers = args.num_layers // num_ranks_in_decoder
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment