Commit 17843605 authored by Vijay Korthikanti

pipeline_fixes

parent 945ece94
@@ -334,10 +334,6 @@ class TransformerLanguageModel(MegatronModule):
         # Decoder (usually set to False, True if part of an encoder-decoder
         # architecture and in decoder-only stage).
         if self.add_decoder:
-            # Temporary assertion until we verify correctness of pipeline parallelism
-            # implementation of T5.
-            assert args.pipeline_model_parallel_size == 1, \
-                'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,
...
@@ -580,6 +580,7 @@ class ParallelTransformer(MegatronModule):
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
                 'virtual_pipeline_model_parallel_size'
+            assert args.model_type != ModelType.encoder_and_decoder
             # Number of layers in each model chunk is the number of layers in the stage,
             # divided by the number of model chunks in a stage.
             self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
@@ -596,7 +597,15 @@ class ParallelTransformer(MegatronModule):
                 (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
-            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+            if args.model_type == ModelType.encoder_and_decoder:
+                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
+                if layer_type == LayerType.encoder:
+                    offset = pipeline_rank * self.num_layers
+                else:
+                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
+                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
+            else:
+                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
...
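For reference, a minimal standalone sketch of the layer-offset arithmetic introduced above for `ModelType.encoder_and_decoder`: the first `pipeline_model_parallel_split_rank` pipeline stages hold encoder layers, and decoder stages re-base their rank by that split before computing the offset. The helper name, plain-integer arguments, and the example configuration below are illustrative only and not part of the commit.

```python
# Illustrative sketch: mirrors the offset logic added to ParallelTransformer
# for ModelType.encoder_and_decoder, using plain integers instead of mpu/args.

def layer_offset(pipeline_rank, num_layers_per_stage, is_encoder_layer,
                 pipeline_split_rank):
    """Return the global index of the first layer built on this stage."""
    if is_encoder_layer:
        # Encoder stages are pipeline ranks [0, pipeline_split_rank).
        return pipeline_rank * num_layers_per_stage
    # Decoder stages are ranks [pipeline_split_rank, ...); re-base them so the
    # first decoder stage starts again at layer 0 of the decoder stack.
    return (pipeline_rank - pipeline_split_rank) * num_layers_per_stage


if __name__ == "__main__":
    # Example: 4 pipeline stages, split after 2 encoder stages, 3 layers per stage.
    split = 2
    for rank in range(4):
        is_enc = rank < split
        off = layer_offset(rank, 3, is_enc, split)
        kind = "encoder" if is_enc else "decoder"
        print(f"rank {rank}: {kind} layers {off + 1}..{off + 3}")
    # rank 0: encoder layers 1..3, rank 1: encoder layers 4..6,
    # rank 2: decoder layers 1..3, rank 3: decoder layers 4..6
```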
@@ -25,6 +25,7 @@ from .initialize import get_data_parallel_group
 from .initialize import get_data_parallel_rank
 from .initialize import get_data_parallel_world_size
 from .initialize import get_embedding_group
+from .initialize import get_position_embedding_group
 from .initialize import get_model_parallel_group
 from .initialize import get_tensor_model_parallel_group
 from .initialize import get_pipeline_model_parallel_group
@@ -32,6 +33,7 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_paralle
 from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
 from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
 from .initialize import is_rank_in_embedding_group
+from .initialize import is_rank_in_position_embedding_group
 from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
 from .initialize import is_pipeline_stage_at_split
 from .initialize import get_num_layers
...
@@ -29,6 +29,8 @@ _PIPELINE_MODEL_PARALLEL_GROUP = None
 _MODEL_PARALLEL_GROUP = None
 # Embedding group.
 _EMBEDDING_GROUP = None
+# Position embedding group.
+_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
@@ -45,6 +47,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
 # A list of ranks that have a copy of the embedding.
 _EMBEDDING_GLOBAL_RANKS = None
+# A list of ranks that have a copy of the position embedding.
+_POSITION_EMBEDDING_GLOBAL_RANKS = None
 # A list of global ranks for each pipeline group to ease calculation of the source
 # rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
@@ -165,6 +170,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     global _EMBEDDING_GLOBAL_RANKS
     assert _EMBEDDING_GROUP is None, \
         'embedding group is already initialized'
+    global _POSITION_EMBEDDING_GROUP
+    global _POSITION_EMBEDDING_GLOBAL_RANKS
+    assert _POSITION_EMBEDDING_GROUP is None, \
+        'position embedding group is already initialized'
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size,
                       num_pipeline_model_parallel_groups)
@@ -176,19 +185,31 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         # first and last stages).
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
-            if pipeline_model_parallel_split_rank_ is not None and \
-                    ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
-                embedding_ranks = [ranks[0],
-                                   ranks[pipeline_model_parallel_split_rank_],
-                                   ranks[-1]]
+            position_embedding_ranks = [ranks[0]]
+            if pipeline_model_parallel_split_rank_ is not None:
+                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+                    embedding_ranks = [ranks[0],
+                                       ranks[pipeline_model_parallel_split_rank_],
+                                       ranks[-1]]
+                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
+                    position_embedding_ranks = [ranks[0],
+                                                ranks[pipeline_model_parallel_split_rank_]]
         else:
             embedding_ranks = ranks
+            position_embedding_ranks = ranks

         group = torch.distributed.new_group(embedding_ranks)
         if rank in embedding_ranks:
             _EMBEDDING_GROUP = group
         if rank in ranks:
             _EMBEDDING_GLOBAL_RANKS = embedding_ranks

+        group = torch.distributed.new_group(position_embedding_ranks)
+        if rank in position_embedding_ranks:
+            _POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks


 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
@@ -234,6 +255,13 @@ def get_embedding_group():
     return _EMBEDDING_GROUP


+def get_position_embedding_group():
+    """Get the position embedding group the caller rank belongs to."""
+    assert _POSITION_EMBEDDING_GROUP is not None, \
+        'position embedding group is not initialized'
+    return _POSITION_EMBEDDING_GROUP
+
+
 def set_tensor_model_parallel_world_size(world_size):
     """Set the tensor model parallel size"""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
...@@ -352,6 +380,13 @@ def is_rank_in_embedding_group(ignore_virtual=False): ...@@ -352,6 +380,13 @@ def is_rank_in_embedding_group(ignore_virtual=False):
return False return False
def is_rank_in_position_embedding_group():
"""Return true if current rank is in position embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _POSITION_EMBEDDING_GLOBAL_RANKS
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
def is_pipeline_stage_before_split(rank=None): def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model """Return True if pipeline stage executes encoder block for a model
with both encoder and decoder.""" with both encoder and decoder."""
@@ -467,3 +502,5 @@ def destroy_model_parallel():
     _DATA_PARALLEL_GROUP = None
     global _EMBEDDING_GROUP
     _EMBEDDING_GROUP = None
+    global _POSITION_EMBEDDING_GROUP
+    _POSITION_EMBEDDING_GROUP = None
...
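To make the group construction above concrete, here is a minimal sketch that reproduces only the rank-selection logic for one pipeline group, without creating any `torch.distributed` process groups. The helper name and the example ranks are hypothetical: with a split rank, the word-embedding group ends up as {first, split, last} while the position-embedding group is {first, split}.

```python
# Illustrative only: mirrors the embedding/position-embedding rank selection
# performed per pipeline group in initialize_model_parallel.

def embedding_group_ranks(ranks, split_rank=None):
    """Return (embedding_ranks, position_embedding_ranks) for one pipeline group."""
    if len(ranks) > 1:
        embedding_ranks = [ranks[0], ranks[-1]]
        position_embedding_ranks = [ranks[0]]
        if split_rank is not None:
            if ranks[split_rank] not in embedding_ranks:
                embedding_ranks = [ranks[0], ranks[split_rank], ranks[-1]]
            if ranks[split_rank] not in position_embedding_ranks:
                position_embedding_ranks = [ranks[0], ranks[split_rank]]
    else:
        # Single-stage pipeline: every rank holds every embedding copy.
        embedding_ranks = ranks
        position_embedding_ranks = ranks
    return embedding_ranks, position_embedding_ranks


# Example: one pipeline group with ranks 0..7, split after 2 encoder stages.
print(embedding_group_ranks(list(range(8)), split_rank=2))
# -> ([0, 2, 7], [0, 2])
```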
@@ -427,6 +427,15 @@ def train_step(forward_step_func, data_iterator,
             else:
                 grad = word_embeddings_weight.grad
             torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
+
+    if mpu.is_rank_in_position_embedding_group() and \
+            mpu.get_pipeline_model_parallel_world_size() > 1 and \
+            args.pipeline_model_parallel_split_rank is not None:
+        unwrapped_model = model[0]
+        unwrapped_model = unwrap_model(
+            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
+        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

     timers('backward-embedding-all-reduce').stop()

     # Update parameters.
...
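The new block mirrors the existing word-embedding all-reduce: with a pipeline split, copies of the position embedding weights live on the ranks of the position-embedding group (first stage and split stage), so after backward their gradients must be summed so the replicas stay identical. Below is a minimal, hypothetical sketch of that pattern; `sync_replicated_grad` is not a Megatron function, and it assumes `torch.distributed` and the relevant group have already been initialized elsewhere.

```python
# Sketch of the "replicated parameter" synchronization pattern used above:
# ranks holding a copy of the parameter all-reduce its gradient over a
# dedicated process group after each backward pass.

import torch.distributed as dist


def sync_replicated_grad(param, group, in_group):
    """Sum the gradient of a replicated parameter across its process group."""
    if not in_group:
        return  # this rank holds no copy of the parameter
    # Prefer the fused main_grad buffer when the local DDP wrapper provides one.
    grad = param.main_grad if hasattr(param, "main_grad") else param.grad
    dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
```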