Commit f332d7e1 authored by Neel Kant's avatar Neel Kant
Browse files

Rename fns to be more precise

parent ac79d374
......@@ -118,9 +118,9 @@ class HashedIndex(object):
def test_retriever():
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
model = load_checkpoint()
model = load_ict_checkpoint()
model.eval()
dataset = get_dataset()
dataset = get_ict_dataset()
hashed_index = HashedIndex.load_from_file('block_hash_data.pkl')
retriever = REALMRetriever(model, dataset, hashed_index)
......@@ -151,12 +151,15 @@ def main():
# allocate the resources well. Have to subsequently assign the correct gpus to the indexing job
# consider initializing everything in a single group and break off processes based on the ranks
# for debugging purposes, make it so that the training process group checks every some number of intervals
# and if it isn't ready, then wait so that it's consistent. Start with using the filesystem
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
args = get_args()
model = load_checkpoint()
model = load_ict_checkpoint()
model.eval()
dataset = get_dataset()
dataset = get_ict_dataset()
data_iter = iter(get_dataloader(dataset))
hashed_index = HashedIndex(embed_size=128, num_buckets=2048)
......@@ -189,7 +192,7 @@ def main():
hashed_index.clear()
def load_checkpoint():
def load_ict_checkpoint():
args = get_args()
model = get_model(model_provider)
......@@ -215,7 +218,7 @@ def load_checkpoint():
return model
def get_dataset():
def get_ict_dataset():
args = get_args()
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment