Commit 1b44a4c4 authored by Neel Kant's avatar Neel Kant
Browse files

add test_retriever

parent 88637044
......@@ -14,6 +14,7 @@ from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider
......@@ -101,6 +102,17 @@ class HashedIndex(object):
self.hash_data = defaultdict(list)
def test_retriever():
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
model = load_checkpoint()
model.eval()
dataset = get_dataset()
hashed_index = HashedIndex(embed_size=128, num_buckets=2048)
retriever = REALMRetriever(model, dataset, hashed_index)
retriever.retrieve_evidence_blocks_text("The last monarch from the house of windsor")
def main():
# TODO
......
......@@ -14,6 +14,6 @@
# limitations under the License.
from .distributed import *
from .bert_model import BertModel, ICTBertModel, REALMBertModel
from .bert_model import BertModel, ICTBertModel, REALMBertModel, REALMRetriever
from .gpt2_model import GPT2Model
from .utils import get_params_for_weight_decay_optimization
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment