Commit 2de7ae27 authored by Jared Casper

Merge branch 't5_pipeline_fixes' into 'main'

Generic fix to T5 pipeline parallelism bug.

See merge request ADLR/megatron-lm!376
parents 9d86ca67 d439b417
@@ -51,8 +51,7 @@ class MegatronModule(torch.nn.Module):
     def word_embeddings_weight(self):
-        if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
-                mpu.get_pipeline_model_parallel_world_size() == 1:
+        if self.pre_process:
             return self.language_model.embedding.word_embeddings.weight
         else:
             if not self.share_word_embeddings:
@@ -85,7 +84,8 @@ class MegatronModule(torch.nn.Module):
         # 3. In the training loop, before an all-reduce between the grads of
         #    the two word_embeddings layers to ensure that every applied weight
         #    update is the same on both stages.
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage() and \
+                not self.pre_process:
             assert not mpu.is_pipeline_first_stage()
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
@@ -99,8 +99,7 @@ class MegatronModule(torch.nn.Module):
         # Zero out initial weights for decoder embedding.
         # NOTE: We don't currently support T5 with the interleaved schedule.
         if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
-                not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
-                mpu.is_rank_in_embedding_group():
+                self.pre_process:
             self.language_model.embedding.zero_parameters()

         # Ensure that first and last stages have the same initial parameter
@@ -109,21 +108,17 @@ class MegatronModule(torch.nn.Module):
             if mpu.is_rank_in_embedding_group():
                 torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                              group=mpu.get_embedding_group())
-            # All-reduce other embeddings as well as necessary. The last stage
-            # does not have these other embeddings, so just create placeholder
-            # tensors of the right shape with all zeros.
-            # NOTE: We don't currently support T5 with the interleaved schedule.
-            if args.pipeline_model_parallel_split_rank is not None:
-                # TODO: Support tokentype embedding.
-                dimensions = (args.max_position_embeddings, args.hidden_size)
-                if mpu.is_pipeline_last_stage(ignore_virtual=True):
-                    position_embeddings = torch.nn.Embedding(*dimensions).cuda()
-                    position_embeddings.weight.data.fill_(0)
-                else:
-                    self.language_model.embedding.cuda()
-                    position_embeddings = self.language_model.embedding.position_embeddings
-                torch.distributed.all_reduce(position_embeddings.weight.data,
-                                             group=mpu.get_embedding_group())
+
+            # Ensure that encoder(first stage) and decoder(split stage) position
+            # embeddings have the same initial parameter values
+            # NOTE: We don't currently support T5 with the interleaved schedule.
+            if mpu.is_rank_in_position_embedding_group() and \
+                    args.pipeline_model_parallel_split_rank is not None:
+                # TODO: Support tokentype embedding.
+                self.language_model.embedding.cuda()
+                position_embeddings = self.language_model.embedding.position_embeddings
+                torch.distributed.all_reduce(position_embeddings.weight.data,
+                                             group=mpu.get_position_embedding_group())
         else:
             print("WARNING! Distributed processes aren't initialized, so "
                   "word embeddings in the last layer are not initialized. "
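Note on the pattern in the @@ -85 hunk above: Megatron ties the input word embeddings (first pipeline stage) to the output head (last pipeline stage) by keeping a zero-initialized copy on the last stage, all-reducing the two copies once at startup so they start identical, and all-reducing their gradients every step so the tied weights never drift. This commit applies the same all-reduce idea to T5's position embeddings through a dedicated position-embedding group spanning the encoder's first stage and the decoder's split stage. The sketch below illustrates only the synchronization pattern; the helper names (sync_tied_embeddings, allreduce_tied_embedding_grads, embedding_group, is_first_stage, is_last_stage) are illustrative and are not part of Megatron-LM's API.

# Minimal sketch of tied word embeddings across pipeline stages.
# Assumes torch.distributed is already initialized and that
# `embedding_group` contains exactly the first and last pipeline ranks;
# all names here are illustrative, not Megatron-LM's actual interface.
import torch
import torch.distributed as dist

def sync_tied_embeddings(word_embeddings: torch.nn.Embedding,
                         embedding_group, is_first_stage, is_last_stage):
    """One-time synchronization after model construction."""
    if not dist.is_initialized():
        return  # single-process run: nothing to synchronize
    if is_last_stage and not is_first_stage:
        # The head's copy starts at zero, so the SUM all-reduce below
        # effectively broadcasts the first stage's values (0 + w == w).
        word_embeddings.weight.data.fill_(0)
    if is_first_stage or is_last_stage:
        dist.all_reduce(word_embeddings.weight.data, group=embedding_group)

def allreduce_tied_embedding_grads(word_embeddings: torch.nn.Embedding,
                                   embedding_group, is_first_stage, is_last_stage):
    """Called every training step, before the optimizer update."""
    if dist.is_initialized() and (is_first_stage or is_last_stage):
        grad = word_embeddings.weight.grad
        if grad is not None:
            dist.all_reduce(grad, group=embedding_group)

Zero-initializing the head's copy is what lets the startup all-reduce act as a broadcast, which is exactly what the "set word_embeddings weights to 0 here, then copy first stage's weights using all_reduce below" comment in the diff relies on.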