Commit 8409e1c1 authored by Neel Kant

Add bm25 evaluation code

parent dfb907fe
import lucene
from java.nio.file import Paths
from org.apache.lucene.analysis.standard import StandardAnalyzer
from org.apache.lucene.document import Document, Field, FieldType
from org.apache.lucene.index import IndexWriter, IndexWriterConfig, IndexOptions, DirectoryReader
from org.apache.lucene.store import SimpleFSDirectory
from org.apache.lucene.search import IndexSearcher
from org.apache.lucene.queryparser.classic import QueryParser
from org.apache.lucene.search.similarities import BM25Similarity
from org.apache.lucene.util import Version
import torch
import torch.distributed as dist
from indexer import get_ict_dataset, get_one_epoch_dataloader
from megatron.initialize import initialize_megatron
from pretrain_bert_ict import get_batch
def setup():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    lucene.initVM(vmargs=['-Djava.awt.headless=true'])
def run():
    dset = get_ict_dataset(use_titles=False, query_in_block_prob=0.1)
    dataloader = iter(get_one_epoch_dataloader(dset))

    index_dir = SimpleFSDirectory(Paths.get("index/"))
    analyzer = StandardAnalyzer()
    analyzer.setMaxTokenLength(1024)

    # field for document ID
    t1 = FieldType()
    t1.setStored(True)
    t1.setTokenized(False)

    # field for document text
    t2 = FieldType()
    t2.setStored(True)
    t2.setTokenized(True)
    t2.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)

    correct = total = 0
    round_correct = torch.zeros(1).cuda()
    round_total = torch.zeros(1).cuda()
    for round in range(100):
        with torch.no_grad():
            try:
                query_tokens, query_pad_mask, \
                    block_tokens, block_pad_mask, block_index_data = get_batch(dataloader)
            except:
                break

        query_tokens = query_tokens.detach().cpu().numpy()
        block_tokens = block_tokens.detach().cpu().numpy()
        query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True) for i in range(query_tokens.shape[0])]
        block_strs = [dset.decode_tokens(block_tokens[i].tolist(), hardcore=True) for i in range(block_tokens.shape[0])]

        # create index writer
        config = IndexWriterConfig(analyzer)
        config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
        writer = IndexWriter(index_dir, config)
        def add_document(text, writer, doc_id):
            doc = Document()
            doc.add(Field("text", text, t2))
            # store the id as a string so it matches Field's (name, value, type) constructor
            doc.add(Field("doc_id", str(doc_id), t1))
            writer.addDocument(doc)

        # add documents to index writer
        for i in range(len(block_strs)):
            add_document(block_strs[i], writer, i)
        # write and finalize the index
        writer.commit()
        writer.close()

        # define BM25 searcher
        searcher = IndexSearcher(DirectoryReader.open(index_dir))
        searcher.setSimilarity(BM25Similarity())

        # feed queries and get scores for everything in the index
        hits_list = []
        for s in query_strs:
            query = QueryParser("text", analyzer).parse(s)
            hits = searcher.search(query, 8).scoreDocs
            hits_list.append(hits)
        # accumulate per-round hit counts, then all-reduce them across data parallel ranks
        for (i, hits) in enumerate(hits_list):
            doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
            round_correct += int(i in doc_ids)
            round_total += 1

        dist.all_reduce(round_correct)
        dist.all_reduce(round_total)

        correct += int(round_correct.item())
        total += int(round_total.item())
        round_correct -= round_correct
        round_total -= round_total

        print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(correct, total, correct / total))

# Plan
# overall accuracy test:
# have index with all blocks. For BERT these are token ids, for BM25 these are tokens
#
# 1. run batch size 4096 BM25 self similarity test. For this I can just detokenize out of the dataset.
# I get the retrieval scores in the forward_step and log the results.
# 2. Create a BM25 index over all of wikipedia, have it ready for use in megatron QA.
#
# Create an index with the block embeddings with block ids
if __name__ == "__main__":
    setup()
    run()
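The per-round loop above is the "self similarity test" from the plan comments: the batch's blocks are indexed with their position in the batch as doc_id, each query is built from one of those blocks, and a query counts as correct when its own block appears in the top-8 BM25 hits. Below is a minimal, Lucene-free sketch of that metric; topk_self_retrieval_accuracy, the word-overlap scorer, and the toy strings are illustrative names and data, not part of this commit.

# Sketch only: top-k self-retrieval accuracy, the metric run() reports per round.
def topk_self_retrieval_accuracy(queries, blocks, score_fn, k=8):
    correct = 0
    for i, q in enumerate(queries):
        # score query i against every block, keep the ids of the k best
        scores = [(score_fn(q, b), j) for j, b in enumerate(blocks)]
        top_ids = [j for _, j in sorted(scores, reverse=True)[:k]]
        # query i is "correct" when its own block id is retrieved
        correct += int(i in top_ids)
    return correct / len(queries)

# Toy usage: word-overlap scoring stands in for BM25.
blocks = ["the cat sat on the mat", "dogs chase cars", "birds can fly"]
queries = ["cat on a mat", "chasing cars", "flying birds"]
overlap = lambda q, b: len(set(q.split()) & set(b.split()))
print(topk_self_retrieval_accuracy(queries, blocks, overlap, k=1))  # 1.0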
@@ -240,7 +240,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
     return model
 
-def get_ict_dataset(use_titles=True):
+def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     args = get_args()
     block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
     titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
@@ -255,7 +255,7 @@ def get_ict_dataset(use_titles=True):
         max_seq_length=args.seq_length,
         short_seq_prob=0.0001,  # doesn't matter
         seed=1,
-        query_in_block_prob=1,
+        query_in_block_prob=query_in_block_prob,
         use_titles=use_titles
     )
     dataset = ICTDataset(**kwargs)
@@ -134,10 +134,30 @@ class ICTDataset(Dataset):
     def encode_text(self, text):
         return self.tokenizer.tokenize(text)
 
-    def decode_tokens(self, token_ids):
+    def decode_tokens(self, token_ids, hardcore=False):
         tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
-        non_pads = [t for t in tokens if t != '[PAD]']
-        return join_str_list(non_pads)
+        exclude_list = ['[PAD]', '[CLS]']
+        if hardcore:
+            extra_exclude = ['[SEP]']
+            exclude_list.extend(extra_exclude)
+        non_pads = [t for t in tokens if t not in exclude_list]
+        joined_strs = join_str_list(non_pads)
+        if hardcore:
+            escape_chars = ['+', '-', '&', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/']
+            skip_me = False
+            joined_strs = list(joined_strs)
+            joined_strs = [s for s in joined_strs if s != '\\']
+            for i, c in enumerate(joined_strs):
+                if skip_me:
+                    skip_me = False
+                    continue
+                if c in escape_chars:
+                    joined_strs.insert(i, '\\')
+                    skip_me = True
+            joined_strs = ''.join(joined_strs)
+        if len(joined_strs) < 3:
+            joined_strs += 'text here'
+        return joined_strs
 
     def get_block(self, start_idx, end_idx, doc_idx):
         """Get the IDs for an evidence block plus the title of the corresponding document"""
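The escaping in the hardcore branch above exists because the evaluation script feeds the decoded block text straight into Lucene's QueryParser, where characters like +, (, ) and ? are query syntax rather than literal text. A standalone sketch of the same idea, using the character list from the diff; LUCENE_SPECIAL and escape_for_queryparser are illustrative names, not part of this commit.

# Sketch only: backslash-escape Lucene QueryParser syntax characters.
LUCENE_SPECIAL = set('+-&!(){}[]^"~*?:/')

def escape_for_queryparser(text):
    # Prefix each special character with a backslash so QueryParser
    # treats it as a literal character instead of query syntax.
    return ''.join('\\' + c if c in LUCENE_SPECIAL else c for c in text)

print(escape_for_queryparser('what is 1+1? (easy)'))  # what is 1\+1\? \(easy\)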