Commit 9873a8da authored by Neel Kant's avatar Neel Kant
Browse files

Reformat parts of BertModel

parent fd33e930
......@@ -159,10 +159,11 @@ class BertModel(MegatronModule):
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_softmax_in_fp32=attention_softmax_in_fp32)
self.lm_head = BertLMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
hidden_size, init_method, layernorm_epsilon, parallel_output)
self._lm_head_key = 'lm_head'
if not self.add_ict_head:
self.lm_head = BertLMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
hidden_size, init_method, layernorm_epsilon, parallel_output)
self._lm_head_key = 'lm_head'
if self.add_binary_head:
self.binary_head = get_linear_layer(hidden_size, 2, init_method)
......@@ -192,15 +193,15 @@ class BertModel(MegatronModule):
tokentype_ids=tokentype_ids)
# Output.
if self.add_ict_head:
ict_logits = self.ict_head(pooled_output)
return ict_logits, None
lm_logits = self.lm_head(
lm_output, self.language_model.embedding.word_embeddings.weight)
if self.add_binary_head:
binary_logits = self.binary_head(pooled_output)
return lm_logits, binary_logits
elif self.add_ict_head:
ict_logits = self.ict_head(pooled_output)
return lm_logits, ict_logits
return lm_logits, None
......@@ -231,14 +232,14 @@ class BertModel(MegatronModule):
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(state_dict[self._lm_head_key],
strict=strict)
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if self.add_binary_head:
self.binary_head.load_state_dict(state_dict[self._binary_head_key],
strict=strict)
self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict)
elif self.add_ict_head:
self.ict_head.load_state_dict(state_dict[self._ict_head_key],
strict=strict)
self.ict_head.load_state_dict(
state_dict[self._ict_head_key], strict=strict)
class ICTBertModel(MegatronModule):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment