Commit 17843605 authored by Vijay Korthikanti

pipeline_fixes

parent 945ece94
......@@ -334,10 +334,6 @@ class TransformerLanguageModel(MegatronModule):
# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder:
# Temporary assertion until we verify correctness of pipeline parallelism
# implementation of T5.
assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
......
......@@ -580,6 +580,7 @@ class ParallelTransformer(MegatronModule):
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
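
A minimal worked example (not part of the commit) of the chunking arithmetic guarded by the assertions above, using made-up values for the layer count and parallel sizes:

# Hypothetical configuration, for illustration only.
num_layers = 24                               # total transformer layers
pipeline_model_parallel_size = 4              # pipeline stages
virtual_pipeline_model_parallel_size = 2      # interleaved model chunks per stage

assert num_layers % virtual_pipeline_model_parallel_size == 0
layers_per_stage = num_layers // pipeline_model_parallel_size                # 6 layers per stage
layers_per_chunk = layers_per_stage // virtual_pipeline_model_parallel_size  # 3 layers per chunk, the division above
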
......@@ -596,7 +597,15 @@ class ParallelTransformer(MegatronModule):
(mpu.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
if args.model_type == ModelType.encoder_and_decoder:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
......
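
A hedged sketch (not taken from the commit) of the new offset logic for an encoder-decoder model: encoder stages index their layers from pipeline rank 0, while decoder stages index from the split rank. The rank, split, and layer counts below are example values only:

# Hypothetical values: 4 pipeline stages, split after 2 encoder stages, 6 layers per stage.
num_layers_per_stage = 6
pipeline_model_parallel_split_rank = 2

def layer_offset(pipeline_rank, is_encoder_layer):
    # Mirrors the branch added above: decoder stages subtract the number of encoder ranks.
    if is_encoder_layer:
        return pipeline_rank * num_layers_per_stage
    return (pipeline_rank - pipeline_model_parallel_split_rank) * num_layers_per_stage

print(layer_offset(0, True))   # 0 -> encoder layers 1..6
print(layer_offset(1, True))   # 6 -> encoder layers 7..12
print(layer_offset(2, False))  # 0 -> decoder layers 1..6
print(layer_offset(3, False))  # 6 -> decoder layers 7..12
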
......@@ -25,6 +25,7 @@ from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
from .initialize import get_embedding_group
from .initialize import get_position_embedding_group
from .initialize import get_model_parallel_group
from .initialize import get_tensor_model_parallel_group
from .initialize import get_pipeline_model_parallel_group
......@@ -32,6 +33,7 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_paralle
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import is_rank_in_embedding_group
from .initialize import is_rank_in_position_embedding_group
from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
......
......@@ -29,6 +29,8 @@ _PIPELINE_MODEL_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
......@@ -45,6 +47,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
......@@ -165,6 +170,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized'
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert _POSITION_EMBEDDING_GROUP is None, \
'position embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size,
num_pipeline_model_parallel_groups)
......@@ -176,19 +185,31 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
if pipeline_model_parallel_split_rank_ is not None and \
ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1]]
position_embedding_ranks = [ranks[0]]
if pipeline_model_parallel_split_rank_ is not None:
if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1]]
if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
position_embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_]]
else:
embedding_ranks = ranks
position_embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
group = torch.distributed.new_group(position_embedding_ranks)
if rank in position_embedding_ranks:
_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
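
To see what the loop above builds, here is a standalone sketch with made-up sizes (8 ranks in one pipeline group, encoder/decoder split at rank 4); it reproduces only the rank bookkeeping, not the torch.distributed group creation:

# Hypothetical pipeline group: global ranks 0..7, split after 4 encoder stages.
ranks = list(range(8))
pipeline_model_parallel_split_rank_ = 4

embedding_ranks = [ranks[0], ranks[-1]]
position_embedding_ranks = [ranks[0]]
if pipeline_model_parallel_split_rank_ is not None:
    if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
        embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank_], ranks[-1]]
    if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
        position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank_]]

print(embedding_ranks)           # [0, 4, 7]: stages holding a copy of the word embeddings
print(position_embedding_ranks)  # [0, 4]: stages holding a copy of the position embeddings
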
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
......@@ -234,6 +255,13 @@ def get_embedding_group():
return _EMBEDDING_GROUP
def get_position_embedding_group():
"""Get the position embedding group the caller rank belongs to."""
assert _POSITION_EMBEDDING_GROUP is not None, \
'position embedding group is not initialized'
return _POSITION_EMBEDDING_GROUP
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
......@@ -352,6 +380,13 @@ def is_rank_in_embedding_group(ignore_virtual=False):
return False
def is_rank_in_position_embedding_group():
"""Return true if current rank is in position embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _POSITION_EMBEDDING_GLOBAL_RANKS
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
......@@ -467,3 +502,5 @@ def destroy_model_parallel():
_DATA_PARALLEL_GROUP = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None
......@@ -427,6 +427,15 @@ def train_step(forward_step_func, data_iterator,
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Update parameters.
......
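
For intuition, a framework-free sketch (shapes and values are illustrative, not from the commit) of what the added position-embedding all-reduce accomplishes: the first encoder stage and the first decoder stage each hold a tied copy of the position embeddings, and summing their gradients keeps the copies in sync after the optimizer step:

import torch

# Stand-in gradients for the two tied copies (hypothetical 512 x 1024 weight).
grad_encoder_first_stage = torch.randn(512, 1024)
grad_decoder_first_stage = torch.randn(512, 1024)

# torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
# sums the gradient over the position embedding group; on two ranks that is:
reduced = grad_encoder_first_stage + grad_decoder_first_stage

# Both ranks would end up holding `reduced`, so the tied weights receive the
# same update and stay identical across the pipeline split.
print(reduced.shape)  # torch.Size([512, 1024])
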