Commit a9f6569d authored by Jimmy Zhang's avatar Jimmy Zhang
Browse files

remove unneccessary argument check

parent 4fcb2f45
...@@ -70,8 +70,7 @@ class MegatronModule(torch.nn.Module): ...@@ -70,8 +70,7 @@ class MegatronModule(torch.nn.Module):
# 3. In the training loop, before an all-reduce between the grads of # 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight # the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages. # update is the same on both stages.
if mpu.is_pipeline_last_stage() and \ if mpu.is_pipeline_last_stage() and not self.pre_process:
not self.pre_process and not self.untie_embeddings_and_output_weights:
assert not mpu.is_pipeline_first_stage() assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head' self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first # set word_embeddings weights to 0 here, then copy first
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment