Commit f2bf5a56 authored by Vijay Korthikanti

minor fixes

parent 17843605
@@ -597,7 +597,8 @@ class ParallelTransformer(MegatronModule):
                 (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
-            if args.model_type == ModelType.encoder_and_decoder:
+            if args.model_type == ModelType.encoder_and_decoder and \
+                    mpu.get_pipeline_model_parallel_world_size() > 1:
                 pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                 if layer_type == LayerType.encoder:
                     offset = pipeline_rank * self.num_layers
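For context (not part of the commit): the added guard matters because, when the pipeline-parallel world size is 1, an encoder-and-decoder model lives entirely on a single stage, so the encoder/decoder-specific per-stage offsets (which rely on a pipeline split between encoder and decoder stages) should not be applied. The following is a hedged, self-contained sketch of that offset logic; the function name, the split_rank parameter, and the decoder branch are illustrative assumptions, not code from this patch.

# Hedged sketch of the guarded layer-offset logic; names such as
# layer_offset and split_rank are hypothetical, not taken from the patch.
def layer_offset(is_encoder, pipeline_rank, pipeline_world_size,
                 num_layers_per_stage, split_rank):
    """Return the index of the first layer owned by this pipeline stage."""
    if pipeline_world_size > 1:
        if is_encoder:
            # Encoder stages come first: rank 0 owns layers [0, n),
            # rank 1 owns [n, 2n), and so on.
            return pipeline_rank * num_layers_per_stage
        # Decoder stages start counting after the encoder stages,
        # which end at split_rank.
        return (pipeline_rank - split_rank) * num_layers_per_stage
    # With a single pipeline stage the whole model is local: no offset.
    return 0

# Example: 4 pipeline stages, split after 2 encoder stages, 6 layers per stage.
assert layer_offset(True, 1, 4, 6, 2) == 6    # second encoder stage
assert layer_offset(False, 3, 4, 6, 2) == 6   # second decoder stage
assert layer_offset(True, 0, 1, 6, 2) == 0    # no pipeline parallelism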
@@ -30,7 +30,7 @@ _MODEL_PARALLEL_GROUP = None
 # Embedding group.
 _EMBEDDING_GROUP = None
 # Position embedding group.
-_POSITION EMBEDDING_GROUP = None
+_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
@@ -208,7 +208,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         if rank in position_embedding_ranks:
             _POSITION_EMBEDDING_GROUP = group
         if rank in ranks:
-            _POSITION_EMBEDDING_GLOBAL_RANKS = embedding_ranks
+            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks


 def model_parallel_is_initialized():
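For context (not part of the commit): _POSITION_EMBEDDING_GLOBAL_RANKS is meant to record the ranks of the position-embedding group, which is generally distinct from the word-embedding group, so assigning it embedding_ranks stored the wrong list. The sketch below illustrates the distinction under the common Megatron-LM convention that word embeddings are tied between the first and last pipeline stages while position embeddings live only on the first stage; the helper function and the example ranks are assumptions for illustration, not code from this commit.

# Hedged illustration of the two rank groups involved; the helper and the
# concrete ranks below are assumptions, not code from the patch.
def embedding_rank_groups(pipeline_ranks):
    """Given the ranks of one pipeline group, return
    (embedding_ranks, position_embedding_ranks)."""
    if len(pipeline_ranks) > 1:
        # Word embeddings are tied between the first and last stage, so both
        # ranks join the embedding gradient all-reduce group.
        embedding_ranks = [pipeline_ranks[0], pipeline_ranks[-1]]
        # Position embeddings live only on the first stage.
        position_embedding_ranks = [pipeline_ranks[0]]
    else:
        embedding_ranks = list(pipeline_ranks)
        position_embedding_ranks = list(pipeline_ranks)
    return embedding_ranks, position_embedding_ranks

# Example: a 4-stage pipeline group made of ranks 0..3.
emb, pos_emb = embedding_rank_groups([0, 1, 2, 3])
assert emb == [0, 3] and pos_emb == [0]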