Commit f332d7e1 authored by Neel Kant's avatar Neel Kant
Browse files

Rename fns to be more precise

parent ac79d374
@@ -118,9 +118,9 @@ class HashedIndex(object):
 def test_retriever():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    model = load_checkpoint()
+    model = load_ict_checkpoint()
     model.eval()
-    dataset = get_dataset()
+    dataset = get_ict_dataset()
     hashed_index = HashedIndex.load_from_file('block_hash_data.pkl')
     retriever = REALMRetriever(model, dataset, hashed_index)
@@ -151,12 +151,15 @@ def main():
     # allocate the resources well. Have to subsequently assign the correct gpus to the indexing job
     # consider initializing everything in a single group and break off processes based on the ranks
+    # for debugging purposes, make it so that the training process group checks every some number of intervals
+    # and if it isn't ready, then wait so that it's consistent. Start with using the filesystem
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
-    model = load_checkpoint()
+    model = load_ict_checkpoint()
     model.eval()
-    dataset = get_dataset()
+    dataset = get_ict_dataset()
     data_iter = iter(get_dataloader(dataset))
     hashed_index = HashedIndex(embed_size=128, num_buckets=2048)
@@ -189,7 +192,7 @@ def main():
     hashed_index.clear()

-def load_checkpoint():
+def load_ict_checkpoint():
     args = get_args()
     model = get_model(model_provider)
@@ -215,7 +218,7 @@ def load_checkpoint():
     return model

-def get_dataset():
+def get_ict_dataset():
     args = get_args()
     block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
     titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment