Commit 8b355d93 authored by Neel Kant

Temporarily modify ict_eval_bm25 to build full wiki index

parent 8409e1c1
import lucene
import sys
from java.nio.file import Paths
from org.apache.lucene.analysis.standard import StandardAnalyzer
@@ -24,14 +25,19 @@ def setup():
    lucene.initVM(vmargs=['-Djava.awt.headless=true'])
-def run():
+def run(embed_all=False):
    dset = get_ict_dataset(use_titles=False, query_in_block_prob=0.1)
    dataloader = iter(get_one_epoch_dataloader(dset))
-    index_dir = SimpleFSDirectory(Paths.get("index/"))
+    index_dir = SimpleFSDirectory(Paths.get("full_wiki_index/"))
    analyzer = StandardAnalyzer()
+    analyzer.setMaxTokenLength(1024)
+    config = IndexWriterConfig(analyzer)
+    config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
+    writer = IndexWriter(index_dir, config)
    # field for document ID
    t1 = FieldType()
    t1.setStored(True)
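
This hunk appears to hoist index construction out of the per-round loop into one-time setup. A self-contained sketch of that setup follows; the imports, the IndexOptions values, and the t2 text field are assumptions inferred from how the fields are used elsewhere in the diff:

from org.apache.lucene.document import Document, Field, FieldType
from org.apache.lucene.index import IndexOptions, IndexWriter, IndexWriterConfig
from org.apache.lucene.store import SimpleFSDirectory

index_dir = SimpleFSDirectory(Paths.get("full_wiki_index/"))
analyzer = StandardAnalyzer()
analyzer.setMaxTokenLength(1024)  # wiki blocks can exceed the analyzer's 255-char default token limit
config = IndexWriterConfig(analyzer)
config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)  # start a fresh index, overwriting any existing one
writer = IndexWriter(index_dir, config)

# stored-only field for the document ID: retrievable from hits, never searched
t1 = FieldType()
t1.setStored(True)
t1.setIndexOptions(IndexOptions.NONE)

# stored and indexed field for the block text: this is what BM25 scores (assumed)
t2 = FieldType()
t2.setStored(True)
t2.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)
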
@@ -46,7 +52,7 @@ def run():
    correct = total = 0
    round_correct = torch.zeros(1).cuda()
    round_total = torch.zeros(1).cuda()
-    for round in range(100):
+    for round in range(100000):
        with torch.no_grad():
            try:
                query_tokens, query_pad_mask, \
@@ -54,19 +60,12 @@ def run():
            except:
                break
-            query_tokens = query_tokens.detach().cpu().numpy()
+            # query_tokens = query_tokens.detach().cpu().numpy()
            block_tokens = block_tokens.detach().cpu().numpy()
-            query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True) for i in range(query_tokens.shape[0])]
+            # query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True) for i in range(query_tokens.shape[0])]
            block_strs = [dset.decode_tokens(block_tokens[i].tolist(), hardcore=True) for i in range(block_tokens.shape[0])]
-            # create index writer
-            config = IndexWriterConfig(analyzer)
-            config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
-            writer = IndexWriter(index_dir, config)
            def add_document(text, writer, doc_id):
                doc = Document()
                doc.add(Field("text", text, t2))
@@ -79,32 +78,36 @@ def run():
            # write and finalize the index
            writer.commit()
            writer.close()
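
The body of add_document is collapsed in the hunk above; a plausible completion, inferred from the doc_id lookup in the search code below:

def add_document(text, writer, doc_id):
    doc = Document()
    doc.add(Field("text", text, t2))
    doc.add(Field("doc_id", str(doc_id), t1))  # inferred: the stored ID the search code reads back
    writer.addDocument(doc)
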
-            # define BM25 searcher
-            searcher = IndexSearcher(DirectoryReader.open(index_dir))
-            searcher.setSimilarity(BM25Similarity())
-            # feed queries and get scores for everything in the index
-            hits_list = []
-            for s in query_strs:
-                query = QueryParser("text", analyzer).parse(s)
-                hits = searcher.search(query, 8).scoreDocs
-                hits_list.append(hits)
-            for (i, hits) in enumerate(hits_list):
-                doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
-                correct += int(i in doc_ids)
-                total += 1
-            dist.all_reduce(round_correct)
-            dist.all_reduce(round_total)
-            correct += int(round_correct.item())
-            total += int(round_total.item())
-            round_correct -= round_correct
-            round_total -= round_total
-            print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(correct, total, correct / total))
+            # searcher = IndexSearcher(DirectoryReader.open(index_dir))
+            # searcher.setSimilarity(BM25Similarity())
+            # # feed queries and get scores for everything in the index
+            # hits_list = []
+            # for s in query_strs:
+            #     query = QueryParser("text", analyzer).parse(s)
+            #     hits = searcher.search(query, 1).scoreDocs
+            #     hits_list.append(hits)
+            # for (i, hits) in enumerate(hits_list):
+            #     doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
+            #     correct += int(i in doc_ids)
+            #     total += 1
+            # dist.all_reduce(round_correct)
+            # dist.all_reduce(round_total)
+            # correct += int(round_correct.item())
+            # total += int(round_total.item())
+            # round_correct -= round_correct
+            # round_total -= round_total
+            # print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(correct, total, correct / total))
+            if round % 10 == 0:
+                print(round)
+    writer.close()
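
For reference, a runnable sketch of the BM25 accuracy check this commit comments out. The per-rank increments of round_correct and round_total are elided in the hunk above, so their placement here is an assumption; QueryParser.escape is also an addition, guarding against Lucene query-syntax characters in the decoded text:

searcher = IndexSearcher(DirectoryReader.open(index_dir))
searcher.setSimilarity(BM25Similarity())  # score with BM25 instead of the default similarity
for i, s in enumerate(query_strs):
    query = QueryParser("text", analyzer).parse(QueryParser.escape(s))
    hits = searcher.search(query, 8).scoreDocs
    retrieved = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
    round_correct += float(i in retrieved)  # query i was cut from block i, so block i is the right answer
    round_total += 1
# sum the per-rank counters across all data-parallel workers
dist.all_reduce(round_correct)
dist.all_reduce(round_total)
correct += int(round_correct.item())
total += int(round_total.item())
round_correct.zero_()
round_total.zero_()
print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(correct, total, correct / total))
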
# Plan
# overall accuracy test:
......