Commit 59031aa7 authored by Neel Kant
Browse files

more for pretrain_realm

parent 002cb170
......@@ -284,10 +284,11 @@ class REALMBertModel(MegatronModule):
class REALMRetriever(MegatronModule):
"""Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
def __init__(self, ict_model, ict_dataset, hashed_index, top_k=5):
def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
super(REALMRetriever, self).__init__()
self.ict_model = ict_model
self.ict_dataset = ict_dataset
self.block_data = block_data
self.hashed_index = hashed_index
self.top_k = top_k
......@@ -320,8 +321,8 @@ class REALMRetriever(MegatronModule):
block_buckets[j] = block_buckets[i].copy()
# [batch_size x max_bucket_population x embed_size]
block_embeds = [torch.cuda.FloatTensor(np.array([self.hashed_index.get_block_embed(arr[3])
for arr in bucket])) for bucket in block_buckets]
block_embeds = [torch.cuda.FloatTensor(np.array([self.block_data.embed_data[idx]
for idx in bucket])) for bucket in block_buckets]
all_top5_tokens, all_top5_pad_masks = [], []
for query_embed, embed_tensor, bucket in zip(query_embeds, block_embeds, block_buckets):
......
......@@ -19,7 +19,7 @@ import torch
import torch.nn.functional as F
from hashed_index import load_ict_checkpoint, get_ict_dataset
from megatron.data.realm_dataset import HashedIndex
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
from megatron import get_args
from megatron import get_timers
from megatron import mpu
......@@ -39,9 +39,10 @@ def model_provider():
ict_model = load_ict_checkpoint()
ict_dataset = get_ict_dataset()
hashed_index = HashedIndex.load_from_file(args.hash_data_path)
all_block_data = BlockData.load_from_file(args.block_data_path)
hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
retriever = REALMRetriever(ict_model, ict_dataset, hashed_index)
retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index)
# TODO: REALMBertModel should accept a path to a pretrained bert-base
model = REALMBertModel(retriever)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment