Commit c044f59a authored by Neel Kant's avatar Neel Kant
Browse files

Some changes noticed late

parent 32bb4edc
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
from .distributed import * from .distributed import *
from .bert_model import BertModel from .bert_model import BertModel
from megatron.model.realm_model import ICTBertModel, REALMRetriever, REALMBertModel from megatron.model.realm_model import ICTBertModel
from .gpt2_model import GPT2Model from .gpt2_model import GPT2Model
from .utils import get_params_for_weight_decay_optimization from .utils import get_params_for_weight_decay_optimization
...@@ -42,14 +42,8 @@ class ICTBertModel(MegatronModule): ...@@ -42,14 +42,8 @@ class ICTBertModel(MegatronModule):
self.block_model = BertModel(**bert_args) self.block_model = BertModel(**bert_args)
self._block_key = 'context_model' self._block_key = 'context_model'
def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False): def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
"""Run a forward pass for each of the models and compute the similarity scores.""" """Run a forward pass for each of the models and return the respective embeddings."""
if only_query:
return self.embed_query(query_tokens, query_attention_mask)
if only_block:
return self.embed_block(block_tokens, block_attention_mask)
query_logits = self.embed_query(query_tokens, query_attention_mask) query_logits = self.embed_query(query_tokens, query_attention_mask)
block_logits = self.embed_block(block_tokens, block_attention_mask) block_logits = self.embed_block(block_tokens, block_attention_mask)
return query_logits, block_logits return query_logits, block_logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment