Commit 1a26b291 authored by Jared Casper

Merge branch 'core-fix' into 'main'

Fix merge error.

See merge request ADLR/megatron-lm!478
parents 42c40715 e0a12fe1
@@ -738,7 +738,7 @@ class NoopTransformerLayer(MegatronModule):
 def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
     """Compute the number of transformer layers resident on the current rank."""
-    if get_pipeline_model_parallel_world_size() > 1:
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
         if is_encoder_and_decoder_model:
             assert args.pipeline_model_parallel_split_rank is not None
@@ -756,11 +756,11 @@ def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
                 'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
             assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
                 'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
-            if is_pipeline_stage_before_split():
+            if mpu.is_pipeline_stage_before_split():
                 num_layers = (
                     0
                     if args.standalone_embedding_stage
-                    and get_pipeline_model_parallel_rank() == 0 else
+                    and mpu.get_pipeline_model_parallel_rank() == 0 else
                     args.encoder_num_layers // num_ranks_in_encoder
                 )
             else:
@@ -777,7 +777,7 @@ def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
             num_layers = (
                 0
                 if args.standalone_embedding_stage
-                and get_pipeline_model_parallel_rank() == 0 else
+                and mpu.get_pipeline_model_parallel_rank() == 0 else
                 args.num_layers // args.transformer_pipeline_model_parallel_size
             )
         else:
...
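For reference, the layer-partitioning arithmetic touched by this hunk can be checked in isolation. The sketch below mirrors the non-encoder-decoder path of _get_num_layers with plain integers so the effect of the standalone embedding stage is visible without a distributed setup. It is an illustrative stand-in under assumed behavior, not Megatron-LM code; layers_on_rank and its arguments are hypothetical names.

# Illustrative sketch only (hypothetical names); mirrors the arithmetic in the
# diff above: with a standalone embedding stage, pipeline rank 0 holds no
# transformer layers and the remaining ranks split the layers evenly.

def layers_on_rank(num_layers, pipeline_world_size, pipeline_rank,
                   standalone_embedding_stage=False):
    """Return how many transformer layers a given pipeline rank would host."""
    if pipeline_world_size == 1:
        return num_layers
    # Ranks that actually hold transformer layers; rank 0 is embedding-only
    # when a standalone embedding stage is used.
    transformer_ranks = (pipeline_world_size - 1
                         if standalone_embedding_stage else pipeline_world_size)
    assert num_layers % transformer_ranks == 0, \
        'num_layers (%d) must be divisible by number of transformer ranks (%d)' % (
            num_layers, transformer_ranks)
    if standalone_embedding_stage and pipeline_rank == 0:
        return 0
    return num_layers // transformer_ranks


if __name__ == '__main__':
    # 24 layers over 4 pipeline stages -> 6 layers per rank.
    print([layers_on_rank(24, 4, r) for r in range(4)])        # [6, 6, 6, 6]
    # With a standalone embedding stage, rank 0 gets 0 and the rest split 24 over 3.
    print([layers_on_rank(24, 4, r, True) for r in range(4)])  # [0, 8, 8, 8]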