Commit 69a546be authored by Deepak Narayanan

Small bugfix in bert_model.py: make sure word_embeddings is initialized before instantiating lm_head
parent 1979c242
@@ -149,6 +149,7 @@ class BertModelBase(PipelinedMegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method)
 
+        self.initialize_word_embeddings(init_method_normal)
         if mpu.is_pipeline_last_stage():
             self.lm_head = BertLMHead(
                 self.word_embeddings_weight().size(0),
@@ -160,8 +161,6 @@ class BertModelBase(PipelinedMegatronModule):
                 init_method)
             self._binary_head_key = 'binary_head'
 
-        self.initialize_word_embeddings(init_method_normal)
-
     def forward(self, bert_model_input, attention_mask,
                 tokentype_ids=None, lm_labels=None):
...
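The ordering matters because BertLMHead's first constructor argument is the vocabulary size, read from the shared word-embedding weight via self.word_embeddings_weight(); before this fix, initialize_word_embeddings ran only after the lm_head was built, so that call touched an embedding that did not exist yet. A minimal toy sketch of the same dependency, using hypothetical names rather than the real Megatron classes:

import torch.nn as nn

# Toy sketch of the ordering dependency fixed above (hypothetical names;
# not the Megatron code). The LM head reads the embedding weight's vocab
# dimension, so the embedding must be created first.
class ToyBert(nn.Module):
    def __init__(self, vocab_size=32, hidden=16):
        super().__init__()
        # Fix: create word_embeddings *before* the head that reads its
        # shape; swapping these two statements raises AttributeError.
        self.word_embeddings = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden,
                                 self.word_embeddings.weight.size(0))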