Commit 5c8238c3 authored by Deepak Narayanan

Fix grad norm computation

parent da1c96e9
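The new `shared = True` flags tag the embedding copies that `zero_parameters` zeroes out, so the gradient-norm computation can skip parameters that are duplicated across pipeline stages rather than counting them once per rank. Below is a minimal sketch of that filtering step, assuming a `param_is_not_shared` helper along the lines of Megatron-LM's utilities (the exact names and signatures here are an assumption, not the repository's verbatim code):

    import torch

    def param_is_not_shared(param):
        # Parameters tagged with `shared = True` (as in this commit) are
        # replicas of weights owned by another pipeline stage; only
        # untagged parameters should contribute to the global norm.
        return not hasattr(param, 'shared') or not param.shared

    def local_grad_norm(parameters, norm_type=2.0):
        # Combine per-parameter gradient norms, skipping shared replicas
        # so duplicated embeddings are not double-counted.
        grads = [p.grad.detach() for p in parameters
                 if p.grad is not None and param_is_not_shared(p)]
        if not grads:
            return torch.tensor(0.0)
        per_grad = torch.stack([torch.norm(g, norm_type) for g in grads])
        return torch.norm(per_grad, norm_type)

In the full clip-grads path the local norms are additionally reduced across model-parallel ranks before clipping; the sketch above shows only the per-rank filtering that the `shared` attribute enables.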
@@ -164,9 +164,12 @@ class Embedding(MegatronModule):
     def zero_parameters(self):
         """Zero out all parameters in embedding."""
         self.word_embeddings.weight.data.fill_(0)
+        self.word_embeddings.weight.shared = True
         self.position_embeddings.weight.data.fill_(0)
+        self.position_embeddings.weight.shared = True
         if self.num_tokentypes > 0:
             self.tokentype_embeddings.weight.data.fill_(0)
+            self.tokentype_embeddings.weight.shared = True

     def add_tokentype_embeddings(self, num_tokentypes):
         """Add token-type embedding. This function is provided so we can add
@@ -331,10 +334,6 @@ class TransformerLanguageModel(MegatronModule):
         # Decoder (usually set to False, True if part of an encoder-decoder
         # architecture and in decoder-only stage).
         if self.add_decoder:
-            # Temporary assertion until we verify correctness of pipeline parallelism
-            # implementation of T5.
-            assert args.pipeline_model_parallel_size == 1, \
-                'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,
...