Commit 371d2ea9 authored by Neel Kant

Complete definition of ICTBertModel

parent 9873a8da
@@ -284,4 +284,36 @@ class ICTBertModel(MegatronModule):
             attention_softmax_in_fp32=attention_softmax_in_fp32)
         self.question_model = BertModel(**bert_args)
-        self.evidence_model = BertModel(**bert_args)
+        self._question_key = 'question_model'
+        self.context_model = BertModel(**bert_args)
+        self._context_key = 'context_model'
+
+    def forward(self, input_tokens, input_attention_mask, input_types,
+                context_tokens, context_attention_mask, context_types):
+        question_ict_logits, _ = self.question_model.forward(input_tokens, input_attention_mask, input_types)
+        context_ict_logits, _ = self.context_model.forward(context_tokens, context_attention_mask, context_types)
+
+        # [batch x h] * [h x batch]
+        retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
+        return retrieval_scores
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+        state_dict_ = {}
+        state_dict_[self._question_key] \
+            = self.question_model.state_dict_for_save_checkpoint(
+                destination, prefix, keep_vars)
+        state_dict_[self._context_key] \
+            = self.context_model.state_dict_for_save_checkpoint(
+                destination, prefix, keep_vars)
+        return state_dict_
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Customized load."""
+        self.question_model.load_state_dict(
+            state_dict[self._question_key], strict=strict)
+        self.context_model.load_state_dict(
+            state_dict[self._context_key], strict=strict)
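
Note on the new forward pass: it returns a [batch x batch] matrix of retrieval scores, where entry (i, j) scores question i against context j. Below is a minimal sketch of how such a matrix is typically consumed for in-batch ICT training; the tensor shapes, random inputs, and cross-entropy objective are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch, not part of this commit. The two logits
# tensors stand in for the outputs of the question/context BertModel
# heads, and the in-batch loss is an assumed training objective.
import torch
import torch.nn.functional as F

batch, hidden = 4, 128
question_ict_logits = torch.randn(batch, hidden)
context_ict_logits = torch.randn(batch, hidden)

# Same scoring as ICTBertModel.forward: [batch x h] * [h x batch].
retrieval_scores = question_ict_logits.matmul(
    torch.transpose(context_ict_logits, 0, 1))

# The diagonal holds the positive pairs, so question i should rank
# its own context i highest; cross-entropy over rows enforces this.
targets = torch.arange(batch)
loss = F.cross_entropy(retrieval_scores, targets)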