Commit bcb320ee authored by Neel Kant

Add ICT-related parameters to BertModel

parent 21a916b1
@@ -74,7 +74,7 @@ class BertLMHead(MegatronModule):
         hidden_size: hidden size
         init_method: init method for weight initialization
         layernorm_epsilon: tolerance for layer norm divisions
-        parallel_output: wether output logits being distributed or not.
+        parallel_output: whether output logits are distributed or not.
     """
     def __init__(self, mpu_vocab_size, hidden_size, init_method,
                  layernorm_epsilon, parallel_output):
@@ -118,6 +118,7 @@ class BertModel(MegatronModule):
                  checkpoint_activations,
                  checkpoint_num_layers=1,
                  add_binary_head=False,
+                 ict_head_size=None,
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  num_tokentypes=0,
@@ -128,8 +129,13 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()
         self.add_binary_head = add_binary_head
+        self.ict_head_size = ict_head_size
+        self.add_ict_head = ict_head_size is not None
+        assert not (self.add_binary_head and self.add_ict_head)
         self.parallel_output = parallel_output
         init_method = init_method_normal(init_method_std)
+        add_pooler = self.add_binary_head or self.add_ict_head

         self.language_model, self._language_model_key = get_language_model(
             num_layers=num_layers,
@@ -141,7 +147,7 @@ class BertModel(MegatronModule):
             output_dropout_prob=output_dropout_prob,
             max_sequence_length=max_sequence_length,
             num_tokentypes=num_tokentypes,
-            add_pooler=self.add_binary_head,
+            add_pooler=add_pooler,
             attention_mask_func=bert_attention_mask_func,
             checkpoint_activations=checkpoint_activations,
             checkpoint_num_layers=checkpoint_num_layers,
@@ -161,7 +167,9 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             self.binary_head = get_linear_layer(hidden_size, 2, init_method)
             self._binary_head_key = 'binary_head'
+        elif self.add_ict_head:
+            self.ict_head = get_linear_layer(hidden_size, ict_head_size, init_method)
+            self._ict_head_key = 'ict_head'

     def forward(self, input_ids, attention_mask,
                 tokentype_ids=None):
@@ -170,7 +178,7 @@ class BertModel(MegatronModule):
             attention_mask, next(self.language_model.parameters()).dtype)
         position_ids = bert_position_ids(input_ids)

-        if self.add_binary_head:
+        if self.add_binary_head or self.add_ict_head:
             lm_output, pooled_output = self.language_model(
                 input_ids,
                 position_ids,
@@ -190,6 +198,9 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             binary_logits = self.binary_head(pooled_output)
             return lm_logits, binary_logits
+        elif self.add_ict_head:
+            ict_logits = self.ict_head(pooled_output)
+            return lm_logits, ict_logits

         return lm_logits, None
@@ -209,6 +220,9 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             state_dict_[self._binary_head_key] \
                 = self.binary_head.state_dict(destination, prefix, keep_vars)
+        elif self.add_ict_head:
+            state_dict_[self._ict_head_key] \
+                = self.ict_head.state_dict(destination, prefix, keep_vars)
         return state_dict_
@@ -222,3 +236,7 @@ class BertModel(MegatronModule):
         if self.add_binary_head:
             self.binary_head.load_state_dict(state_dict[self._binary_head_key],
                                              strict=strict)
+        elif self.add_ict_head:
+            self.ict_head.load_state_dict(state_dict[self._ict_head_key],
+                                          strict=strict)
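
For readers skimming the diff: below is a minimal standalone sketch (not part of the commit) of the mutually exclusive head selection it introduces. The class name PooledHead and the sizes 768/128 are hypothetical, and plain nn.Linear stands in for Megatron's get_linear_layer.

import torch
import torch.nn as nn

class PooledHead(nn.Module):
    """Either a 2-way binary head or an ICT head over the pooled output."""
    def __init__(self, hidden_size, add_binary_head=False, ict_head_size=None):
        super().__init__()
        self.add_binary_head = add_binary_head
        self.add_ict_head = ict_head_size is not None
        # The two heads are mutually exclusive, as asserted in the diff.
        assert not (self.add_binary_head and self.add_ict_head)
        if self.add_binary_head:
            self.head = nn.Linear(hidden_size, 2)
        elif self.add_ict_head:
            self.head = nn.Linear(hidden_size, ict_head_size)
        else:
            self.head = None

    def forward(self, pooled_output):
        # Mirrors the forward logic: extra logits only when a head exists.
        return None if self.head is None else self.head(pooled_output)

# Hypothetical usage: an ICT head projecting the pooled state to 128 dims.
head = PooledHead(hidden_size=768, ict_head_size=128)
print(head(torch.randn(4, 768)).shape)  # torch.Size([4, 128])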