"doc/vscode:/vscode.git/clone" did not exist on "0e4b7a3929e12d1645e3e177148d15cd4cdec793"
Commit 26ea8314 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

fixing the corner case pp=2

parent 81ad628e
...@@ -51,8 +51,8 @@ class MegatronModule(torch.nn.Module): ...@@ -51,8 +51,8 @@ class MegatronModule(torch.nn.Module):
def word_embeddings_weight(self): def word_embeddings_weight(self):
if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \ if hasattr(self.language_model, 'embedding') and \
mpu.get_pipeline_model_parallel_world_size() == 1: self.language_model.embedding is not None:
return self.language_model.embedding.word_embeddings.weight return self.language_model.embedding.word_embeddings.weight
else: else:
if not self.share_word_embeddings: if not self.share_word_embeddings:
...@@ -99,8 +99,9 @@ class MegatronModule(torch.nn.Module): ...@@ -99,8 +99,9 @@ class MegatronModule(torch.nn.Module):
# Zero out initial weights for decoder embedding. # Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule. # NOTE: We don't currently support T5 with the interleaved schedule.
if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \ if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
not mpu.is_pipeline_last_stage(ignore_virtual=True) and \ mpu.is_rank_in_embedding_group() and \
mpu.is_rank_in_embedding_group(): hasattr(self.language_model, 'embedding') and \
self.language_model.embedding is not None:
self.language_model.embedding.zero_parameters() self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter # Ensure that first and last stages have the same initial parameter
...@@ -109,21 +110,18 @@ class MegatronModule(torch.nn.Module): ...@@ -109,21 +110,18 @@ class MegatronModule(torch.nn.Module):
if mpu.is_rank_in_embedding_group(): if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data, torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group()) group=mpu.get_embedding_group())
# All-reduce other embeddings as well as necessary. The last stage
# does not have these other embeddings, so just create placeholder # All-reduce other embeddings as well as necessary. The last stage
# tensors of the right shape with all zeros. # does not have these other embeddings, so just create placeholder
# NOTE: We don't currently support T5 with the interleaved schedule. # tensors of the right shape with all zeros.
if args.pipeline_model_parallel_split_rank is not None: # NOTE: We don't currently support T5 with the interleaved schedule.
# TODO: Support tokentype embedding. if mpu.is_rank_in_position_embedding_group() and \
dimensions = (args.max_position_embeddings, args.hidden_size) args.pipeline_model_parallel_split_rank is not None:
if mpu.is_pipeline_last_stage(ignore_virtual=True): # TODO: Support tokentype embedding.
position_embeddings = torch.nn.Embedding(*dimensions).cuda() self.language_model.embedding.cuda()
position_embeddings.weight.data.fill_(0) position_embeddings = self.language_model.embedding.position_embeddings
else: torch.distributed.all_reduce(position_embeddings.weight.data,
self.language_model.embedding.cuda() group=mpu.get_position_embedding_group())
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_embedding_group())
else: else:
print("WARNING! Distributed processes aren't initialized, so " print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. " "word embeddings in the last layer are not initialized. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment