Commit 8b355d93 authored by Neel Kant's avatar Neel Kant
Browse files

Temporarily modify ict_eval_bm25 to build full wiki index

parent 8409e1c1
import lucene import lucene
import sys
from java.nio.file import Paths from java.nio.file import Paths
from org.apache.lucene.analysis.standard import StandardAnalyzer from org.apache.lucene.analysis.standard import StandardAnalyzer
...@@ -24,14 +25,19 @@ def setup(): ...@@ -24,14 +25,19 @@ def setup():
lucene.initVM(vmargs=['-Djava.awt.headless=true']) lucene.initVM(vmargs=['-Djava.awt.headless=true'])
def run(): def run(embed_all=False):
dset = get_ict_dataset(use_titles=False, query_in_block_prob=0.1) dset = get_ict_dataset(use_titles=False, query_in_block_prob=0.1)
dataloader = iter(get_one_epoch_dataloader(dset)) dataloader = iter(get_one_epoch_dataloader(dset))
index_dir = SimpleFSDirectory(Paths.get("index/")) index_dir = SimpleFSDirectory(Paths.get("full_wiki_index/"))
analyzer = StandardAnalyzer() analyzer = StandardAnalyzer()
analyzer.setMaxTokenLength(1024) analyzer.setMaxTokenLength(1024)
config = IndexWriterConfig(analyzer)
config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
writer = IndexWriter(index_dir, config)
# field for document ID # field for document ID
t1 = FieldType() t1 = FieldType()
t1.setStored(True) t1.setStored(True)
...@@ -46,7 +52,7 @@ def run(): ...@@ -46,7 +52,7 @@ def run():
correct = total = 0 correct = total = 0
round_correct = torch.zeros(1).cuda() round_correct = torch.zeros(1).cuda()
round_total = torch.zeros(1).cuda() round_total = torch.zeros(1).cuda()
for round in range(100): for round in range(100000):
with torch.no_grad(): with torch.no_grad():
try: try:
query_tokens, query_pad_mask, \ query_tokens, query_pad_mask, \
...@@ -54,19 +60,12 @@ def run(): ...@@ -54,19 +60,12 @@ def run():
except: except:
break break
query_tokens = query_tokens.detach().cpu().numpy() # query_tokens = query_tokens.detach().cpu().numpy()
block_tokens = block_tokens.detach().cpu().numpy() block_tokens = block_tokens.detach().cpu().numpy()
query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True) for i in range(query_tokens.shape[0])] # query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True) for i in range(query_tokens.shape[0])]
block_strs = [dset.decode_tokens(block_tokens[i].tolist(), hardcore=True) for i in range(block_tokens.shape[0])] block_strs = [dset.decode_tokens(block_tokens[i].tolist(), hardcore=True) for i in range(block_tokens.shape[0])]
# create index writer
config = IndexWriterConfig(analyzer)
config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
writer = IndexWriter(index_dir, config)
def add_document(text, writer, doc_id): def add_document(text, writer, doc_id):
doc = Document() doc = Document()
doc.add(Field("text", text, t2)) doc.add(Field("text", text, t2))
...@@ -79,32 +78,36 @@ def run(): ...@@ -79,32 +78,36 @@ def run():
# write and finalize the index # write and finalize the index
writer.commit() writer.commit()
writer.close()
# define BM25 searcher # define BM25 searcher
searcher = IndexSearcher(DirectoryReader.open(index_dir)) # searcher = IndexSearcher(DirectoryReader.open(index_dir))
searcher.setSimilarity(BM25Similarity()) # searcher.setSimilarity(BM25Similarity())
# feed queries and get scores for everything in the index # # feed queries and get scores for everything in the index
hits_list = [] # hits_list = []
for s in query_strs: # for s in query_strs:
query = QueryParser("text", analyzer).parse(s) # query = QueryParser("text", analyzer).parse(s)
hits = searcher.search(query, 8).scoreDocs # hits = searcher.search(query, 1).scoreDocs
hits_list.append(hits) # hits_list.append(hits)
for (i, hits) in enumerate(hits_list): # for (i, hits) in enumerate(hits_list):
doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits] # doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
correct += int(i in doc_ids) # correct += int(i in doc_ids)
total += 1 # total += 1
dist.all_reduce(round_correct) # dist.all_reduce(round_correct)
dist.all_reduce(round_total) # dist.all_reduce(round_total)
correct += int(round_correct.item())
total += int(round_total.item()) # correct += int(round_correct.item())
round_correct -= round_correct # total += int(round_total.item())
round_total -= round_total
# round_correct -= round_correct
print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(correct, total, correct / total)) # round_total -= round_total
# print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(correct, total, correct / total))
if round % 10 == 0:
print(round)
writer.close()
# Plan # Plan
# overall accuracy test: # overall accuracy test:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment