Commit 9da6e975 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

simplify code

parent ade99d61
...@@ -51,8 +51,7 @@ class MegatronModule(torch.nn.Module): ...@@ -51,8 +51,7 @@ class MegatronModule(torch.nn.Module):
def word_embeddings_weight(self): def word_embeddings_weight(self):
if hasattr(self.language_model, 'embedding') and \ if self.pre_process:
self.language_model.embedding is not None:
return self.language_model.embedding.word_embeddings.weight return self.language_model.embedding.word_embeddings.weight
else: else:
if not self.share_word_embeddings: if not self.share_word_embeddings:
...@@ -86,8 +85,7 @@ class MegatronModule(torch.nn.Module): ...@@ -86,8 +85,7 @@ class MegatronModule(torch.nn.Module):
# the two word_embeddings layers to ensure that every applied weight # the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages. # update is the same on both stages.
if mpu.is_pipeline_last_stage() and \ if mpu.is_pipeline_last_stage() and \
(not hasattr(self.language_model, 'embedding') or not self.pre_process:
self.language_model.embedding is None):
assert not mpu.is_pipeline_first_stage() assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head' self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first # set word_embeddings weights to 0 here, then copy first
...@@ -101,9 +99,7 @@ class MegatronModule(torch.nn.Module): ...@@ -101,9 +99,7 @@ class MegatronModule(torch.nn.Module):
# Zero out initial weights for decoder embedding. # Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule. # NOTE: We don't currently support T5 with the interleaved schedule.
if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \ if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
mpu.is_rank_in_embedding_group() and \ self.pre_process:
hasattr(self.language_model, 'embedding') and \
self.language_model.embedding is not None:
self.language_model.embedding.zero_parameters() self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter # Ensure that first and last stages have the same initial parameter
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment