"vscode:/vscode.git/clone" did not exist on "757ecfb0adf2df031a4570430eaa9085fe20a5c0"
Commit c044f59a authored by Neel Kant's avatar Neel Kant
Browse files

Some changes noticed late

parent 32bb4edc
......@@ -15,6 +15,6 @@
from .distributed import *
from .bert_model import BertModel
from megatron.model.realm_model import ICTBertModel, REALMRetriever, REALMBertModel
from megatron.model.realm_model import ICTBertModel
from .gpt2_model import GPT2Model
from .utils import get_params_for_weight_decay_optimization
......@@ -42,14 +42,8 @@ class ICTBertModel(MegatronModule):
self.block_model = BertModel(**bert_args)
self._block_key = 'context_model'
def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
"""Run a forward pass for each of the models and compute the similarity scores."""
if only_query:
return self.embed_query(query_tokens, query_attention_mask)
if only_block:
return self.embed_block(block_tokens, block_attention_mask)
def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
"""Run a forward pass for each of the models and return the respective embeddings."""
query_logits = self.embed_query(query_tokens, query_attention_mask)
block_logits = self.embed_block(block_tokens, block_attention_mask)
return query_logits, block_logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment