Commit 2f6d2a3a authored by Neel Kant

Fix ICTBertModel args

parent b03af49e
@@ -215,44 +215,16 @@ class BertModel(MegatronModule):
 class ICTBertModel(MegatronModule):
     def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
                  ict_head_size,
-                 checkpoint_num_layers=1,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
                  num_tokentypes=0,
-                 parallel_output=True,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+                 parallel_output=True):
         super(ICTBertModel, self).__init__()
         bert_args = dict(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
-            checkpoint_activations=checkpoint_activations,
+            num_tokentypes=num_tokentypes,
             add_binary_head=False,
             ict_head_size=ict_head_size,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
-            init_method_std=init_method_std,
-            num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            parallel_output=parallel_output
+        )
         self.question_model = BertModel(**bert_args)
         self._question_key = 'question_model'
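For reference, a minimal usage sketch of the constructor after this change: only the ICT-specific options are passed explicitly, since the remaining model hyperparameters are read from Megatron's global arguments inside BertModel. The import path and the example values below are illustrative assumptions, not part of the commit, and the snippet presumes the usual Megatron-LM initialization has already run.

# Hedged sketch of the simplified constructor; import path and values are assumptions.
from megatron.model import ICTBertModel

ict_model = ICTBertModel(
    ict_head_size=128,     # size of the ICT retrieval head (example value)
    num_tokentypes=2,      # token-type embeddings, as in standard BERT
    parallel_output=True)  # keep logits split across model-parallel ranks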