Commit bcb320ee authored by Neel Kant

Add ICT-related parameters to BertModel

parent 21a916b1
@@ -74,7 +74,7 @@ class BertLMHead(MegatronModule):
         hidden_size: hidden size
         init_method: init method for weight initialization
         layernorm_epsilon: tolerance for layer norm divisions
-        parallel_output: wether output logits being distributed or not.
+        parallel_output: whether output logits being distributed or not.
     """
     def __init__(self, mpu_vocab_size, hidden_size, init_method,
                  layernorm_epsilon, parallel_output):
@@ -118,6 +118,7 @@ class BertModel(MegatronModule):
                  checkpoint_activations,
                  checkpoint_num_layers=1,
                  add_binary_head=False,
+                 ict_head_size=None,
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  num_tokentypes=0,
@@ -128,8 +129,13 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()

         self.add_binary_head = add_binary_head
+        self.ict_head_size = ict_head_size
+        self.add_ict_head = ict_head_size is not None
+        assert not (self.add_binary_head and self.add_ict_head)
         self.parallel_output = parallel_output
         init_method = init_method_normal(init_method_std)
+        add_pooler = self.add_binary_head or self.add_ict_head

         self.language_model, self._language_model_key = get_language_model(
             num_layers=num_layers,
@@ -141,7 +147,7 @@ class BertModel(MegatronModule):
             output_dropout_prob=output_dropout_prob,
             max_sequence_length=max_sequence_length,
             num_tokentypes=num_tokentypes,
-            add_pooler=self.add_binary_head,
+            add_pooler=add_pooler,
             attention_mask_func=bert_attention_mask_func,
             checkpoint_activations=checkpoint_activations,
             checkpoint_num_layers=checkpoint_num_layers,
@@ -161,7 +167,9 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             self.binary_head = get_linear_layer(hidden_size, 2, init_method)
             self._binary_head_key = 'binary_head'
+        elif self.add_ict_head:
+            self.ict_head = get_linear_layer(hidden_size, ict_head_size, init_method)
+            self._ict_head_key = 'ict_head'

     def forward(self, input_ids, attention_mask,
                 tokentype_ids=None):
@@ -170,7 +178,7 @@ class BertModel(MegatronModule):
             attention_mask, next(self.language_model.parameters()).dtype)
         position_ids = bert_position_ids(input_ids)

-        if self.add_binary_head:
+        if self.add_binary_head or self.add_ict_head:
             lm_output, pooled_output = self.language_model(
                 input_ids,
                 position_ids,
@@ -190,6 +198,9 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             binary_logits = self.binary_head(pooled_output)
             return lm_logits, binary_logits
+        elif self.add_ict_head:
+            ict_logits = self.ict_head(pooled_output)
+            return lm_logits, ict_logits

         return lm_logits, None
@@ -209,6 +220,9 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             state_dict_[self._binary_head_key] \
                 = self.binary_head.state_dict(destination, prefix, keep_vars)
+        elif self.add_ict_head:
+            state_dict_[self._ict_head_key] \
+                = self.ict_head.state_dict(destination, prefix, keep_vars)

         return state_dict_
@@ -222,3 +236,7 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             self.binary_head.load_state_dict(state_dict[self._binary_head_key],
                                              strict=strict)
+        elif self.add_ict_head:
+            self.ict_head.load_state_dict(state_dict[self._ict_head_key],
+                                          strict=strict)
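For context, here is a minimal standalone sketch of the head-selection logic this commit adds. It uses plain torch.nn.Linear as a stand-in for Megatron's get_linear_layer and a random tensor in place of the real pooler output; the hidden size, batch size, and ict_head_size values are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn

hidden_size = 768      # assumed model hidden size
ict_head_size = 128    # mirrors the new ict_head_size argument
add_binary_head = False

# The commit derives the ICT flag from the argument and forbids
# enabling both heads at once.
add_ict_head = ict_head_size is not None
assert not (add_binary_head and add_ict_head)

# The pooler is requested whenever either head consumes pooled_output.
add_pooler = add_binary_head or add_ict_head

# Stand-in for get_linear_layer(hidden_size, ict_head_size, init_method).
ict_head = nn.Linear(hidden_size, ict_head_size)

pooled_output = torch.randn(4, hidden_size)  # stand-in for the pooler output
ict_logits = ict_head(pooled_output)         # shape: [4, ict_head_size]
print(ict_logits.shape)                      # torch.Size([4, 128])

With ict_head_size set, BertModel.forward returns (lm_logits, ict_logits) rather than (lm_logits, binary_logits) or (lm_logits, None), and the ICT head is saved and restored through the same state-dict path as the binary head.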