"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "01a03caa208ba24cf21e2c131d9c596eaf731660"
Commit 9873a8da authored by Neel Kant

Reformat parts of BertModel

parent fd33e930
@@ -159,10 +159,11 @@ class BertModel(MegatronModule):
             apply_query_key_layer_scaling=apply_query_key_layer_scaling,
             attention_softmax_in_fp32=attention_softmax_in_fp32)
 
-        self.lm_head = BertLMHead(
-            self.language_model.embedding.word_embeddings.weight.size(0),
-            hidden_size, init_method, layernorm_epsilon, parallel_output)
-        self._lm_head_key = 'lm_head'
+        if not self.add_ict_head:
+            self.lm_head = BertLMHead(
+                self.language_model.embedding.word_embeddings.weight.size(0),
+                hidden_size, init_method, layernorm_epsilon, parallel_output)
+            self._lm_head_key = 'lm_head'
 
         if self.add_binary_head:
             self.binary_head = get_linear_layer(hidden_size, 2, init_method)
@@ -192,15 +193,15 @@ class BertModel(MegatronModule):
             tokentype_ids=tokentype_ids)
 
         # Output.
+        if self.add_ict_head:
+            ict_logits = self.ict_head(pooled_output)
+            return ict_logits, None
+
         lm_logits = self.lm_head(
             lm_output, self.language_model.embedding.word_embeddings.weight)
 
         if self.add_binary_head:
             binary_logits = self.binary_head(pooled_output)
             return lm_logits, binary_logits
-        elif self.add_ict_head:
-            ict_logits = self.ict_head(pooled_output)
-            return lm_logits, ict_logits
 
         return lm_logits, None
@@ -231,14 +232,14 @@ class BertModel(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        self.lm_head.load_state_dict(state_dict[self._lm_head_key],
-                                     strict=strict)
+        self.lm_head.load_state_dict(
+            state_dict[self._lm_head_key], strict=strict)
         if self.add_binary_head:
-            self.binary_head.load_state_dict(state_dict[self._binary_head_key],
-                                             strict=strict)
+            self.binary_head.load_state_dict(
+                state_dict[self._binary_head_key], strict=strict)
         elif self.add_ict_head:
-            self.ict_head.load_state_dict(state_dict[self._ict_head_key],
-                                          strict=strict)
+            self.ict_head.load_state_dict(
+                state_dict[self._ict_head_key], strict=strict)
 
 
 class ICTBertModel(MegatronModule):
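
The net effect of the change is easiest to see in isolation. Below is a minimal, self-contained sketch, not Megatron code: `TinyBertHeads` and its `nn.Linear` heads are simplified stand-ins for `BertLMHead`, `get_linear_layer`, and the ICT head. It mirrors the head-selection logic this commit introduces: the LM head is only built when `add_ict_head` is off, and the forward pass returns `(ict_logits, None)` early when ICT is on.

```python
import torch
import torch.nn as nn


class TinyBertHeads(nn.Module):
    """Toy module mirroring the head-selection logic in this commit.

    When add_ict_head is set, the LM head is never constructed and
    forward() returns (ict_logits, None) early; otherwise the LM head
    runs and the binary head is consulted as before.
    """

    def __init__(self, hidden_size, vocab_size,
                 add_binary_head=False, add_ict_head=False):
        super().__init__()
        self.add_binary_head = add_binary_head
        self.add_ict_head = add_ict_head
        if not self.add_ict_head:
            # LM head is only built when ICT is off (as in the diff above).
            self.lm_head = nn.Linear(hidden_size, vocab_size)
        if self.add_binary_head:
            self.binary_head = nn.Linear(hidden_size, 2)
        elif self.add_ict_head:
            self.ict_head = nn.Linear(hidden_size, hidden_size)

    def forward(self, lm_output, pooled_output):
        # ICT path: return retrieval logits and skip the LM head entirely.
        if self.add_ict_head:
            ict_logits = self.ict_head(pooled_output)
            return ict_logits, None

        lm_logits = self.lm_head(lm_output)

        if self.add_binary_head:
            binary_logits = self.binary_head(pooled_output)
            return lm_logits, binary_logits

        return lm_logits, None


if __name__ == "__main__":
    heads = TinyBertHeads(hidden_size=8, vocab_size=100, add_ict_head=True)
    out, extra = heads(torch.randn(2, 4, 8), torch.randn(2, 8))
    print(out.shape, extra)  # torch.Size([2, 8]) None
```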