Commit 69a546be authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Small bugfix in bert_model.py: make sure word_embeddings is initialized before...

Small bugfix in bert_model.py: make sure word_embeddings is initialized before instantiating lm_head
parent 1979c242
......@@ -149,6 +149,7 @@ class BertModelBase(PipelinedMegatronModule):
init_method=init_method,
scaled_init_method=scaled_init_method)
self.initialize_word_embeddings(init_method_normal)
if mpu.is_pipeline_last_stage():
self.lm_head = BertLMHead(
self.word_embeddings_weight().size(0),
......@@ -160,8 +161,6 @@ class BertModelBase(PipelinedMegatronModule):
init_method)
self._binary_head_key = 'binary_head'
self.initialize_word_embeddings(init_method_normal)
def forward(self, bert_model_input, attention_mask,
tokentype_ids=None, lm_labels=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment