Commit fd33e930 authored by Neel Kant

Create ICTBertModel and update model/__init__.py

parent bcb320ee
@@ -14,6 +14,6 @@
 # limitations under the License.
 from .distributed import *
-from .bert_model import BertModel
+from .bert_model import BertModel, ICTBertModel
 from .gpt2_model import GPT2Model
 from .utils import get_params_for_weight_decay_optimization
@@ -240,3 +240,47 @@ class BertModel(MegatronModule):
         self.ict_head.load_state_dict(state_dict[self._ict_head_key],
                                       strict=strict)
class ICTBertModel(MegatronModule):
    """BERT-based model for the Inverse Cloze Task (ICT).

    Wraps two BertModel instances built from identical arguments: one
    encodes questions/queries, the other encodes evidence blocks, each
    with an ICT head of size ict_head_size and no binary (NSP) head.
    """

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 ict_head_size,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=0,
                 parallel_output=True,
                 apply_query_key_layer_scaling=False,
                 attention_softmax_in_fp32=False):
        super(ICTBertModel, self).__init__()

        # Shared configuration for both encoders: the binary head is
        # disabled and an ICT head of size ict_head_size is used instead.
        bert_args = dict(
            num_layers=num_layers,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            embedding_dropout_prob=embedding_dropout_prob,
            attention_dropout_prob=attention_dropout_prob,
            output_dropout_prob=output_dropout_prob,
            max_sequence_length=max_sequence_length,
            checkpoint_activations=checkpoint_activations,
            add_binary_head=False,
            ict_head_size=ict_head_size,
            checkpoint_num_layers=checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            init_method_std=init_method_std,
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32)

        # Two separate encoders (no weight sharing): one for questions,
        # one for evidence blocks.
        self.question_model = BertModel(**bert_args)
        self.evidence_model = BertModel(**bert_args)
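This commit only adds the constructor; it does not define a forward pass. In the usual ICT setup, the question encoder and the evidence encoder each produce a fixed-size embedding, retrieval scores are dot products between query and block embeddings, and the training objective treats the block a query was drawn from as the positive among in-batch negatives. The sketch below illustrates that pattern with toy bag-of-embeddings encoders standing in for the two BertModel instances, so it runs on its own; the toy class, its dimensions, and the dot-product/cross-entropy scoring are illustrative assumptions, not part of this commit.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyICTDualEncoder(nn.Module):
    """Toy stand-in for ICTBertModel: two independent encoders that map
    token sequences to ict_head_size-dimensional embeddings."""

    def __init__(self, vocab_size=1000, hidden_size=64, ict_head_size=32):
        super().__init__()
        # In ICTBertModel these are two full BertModel instances; here a
        # bag-of-embeddings encoder keeps the sketch self-contained.
        self.question_model = nn.Sequential(
            nn.EmbeddingBag(vocab_size, hidden_size),
            nn.Linear(hidden_size, ict_head_size))
        self.evidence_model = nn.Sequential(
            nn.EmbeddingBag(vocab_size, hidden_size),
            nn.Linear(hidden_size, ict_head_size))

    def forward(self, question_tokens, block_tokens):
        q = self.question_model(question_tokens)   # [batch, ict_head_size]
        b = self.evidence_model(block_tokens)      # [batch, ict_head_size]
        # Retrieval scores: dot product between every query and every block.
        return torch.matmul(q, b.t())              # [batch, batch]

model = ToyICTDualEncoder()
questions = torch.randint(0, 1000, (4, 16))   # 4 queries, 16 tokens each
blocks = torch.randint(0, 1000, (4, 128))     # 4 evidence blocks
scores = model(questions, blocks)
# In-batch ICT objective: each query's true block sits on the diagonal.
loss = F.cross_entropy(scores, torch.arange(4))
print(scores.shape, loss.item())
```

Keeping the two encoders as separate, non-weight-sharing modules (as the commit does with question_model and evidence_model) allows block embeddings to be precomputed and indexed independently of the query encoder at retrieval time.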