Commit 3fb02b8e authored by Neel Kant
Browse files

HashedIndex.load_from_file

parent 1b44a4c4
......@@ -101,6 +101,17 @@ class HashedIndex(object):
self.block_data = defaultdict(list)
self.hash_data = defaultdict(list)
@classmethod
def load_from_file(cls, fname):
    """Rebuild an index from a pickled state dict stored at ``fname``.

    The pickle is expected to contain a dict with the keys
    ``'hash_matrix'``, ``'block_data'`` and ``'hash_data'``. The new index
    is constructed from the shape of the stored hash matrix and then has
    the three stored objects attached to it as attributes.

    Args:
        fname: Path of the pickle file to read.

    Returns:
        A new instance of ``cls`` populated from the file.

    Raises:
        OSError: If ``fname`` cannot be opened.
        KeyError: If one of the expected keys is missing from the pickle.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load index
    # files that come from a trusted source.
    # The context manager guarantees the handle is closed even when
    # unpickling fails.
    with open(fname, 'rb') as index_file:
        state_dict = pickle.load(index_file)

    hash_matrix = state_dict['hash_matrix']
    # Constructor arguments are derived from the matrix: its first
    # dimension is passed through unchanged and its second dimension is
    # doubled. NOTE(review): presumably (embed_size, num_buckets) -- confirm
    # against __init__.
    # Use ``cls`` so that subclasses get an instance of their own type.
    new_index = cls(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
    new_index.block_data = state_dict['block_data']
    new_index.hash_data = state_dict['hash_data']
    new_index.hash_matrix = hash_matrix
    return new_index
def test_retriever():
initialize_megatron(extra_args_provider=None,
......@@ -108,7 +119,7 @@ def test_retriever():
model = load_checkpoint()
model.eval()
dataset = get_dataset()
hashed_index = HashedIndex(embed_size=128, num_buckets=2048)
hashed_index = HashedIndex.load_from_file('block_hash_data.pkl')
retriever = REALMRetriever(model, dataset, hashed_index)
retriever.retrieve_evidence_blocks_text("The last monarch from the house of windsor")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment