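"""Evaluate a BM25 baseline (via PyLucene) on ICT-style query/block pairs.

For each batch from the ICT dataset, the blocks are indexed with Lucene and every
query is run against that index; retrieval counts as correct when the query's own
block appears among its top-8 BM25 hits.
"""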
import lucene

from java.nio.file import Paths
from org.apache.lucene.analysis.standard import StandardAnalyzer
from org.apache.lucene.document import Document, Field, FieldType
from org.apache.lucene.index import IndexWriter, IndexWriterConfig, IndexOptions, DirectoryReader
from org.apache.lucene.store import SimpleFSDirectory
from org.apache.lucene.search import IndexSearcher
from org.apache.lucene.queryparser.classic import QueryParser
from org.apache.lucene.search.similarities import BM25Similarity

import torch
import torch.distributed as dist

from indexer import get_ict_dataset, get_one_epoch_dataloader
from megatron.initialize import initialize_megatron
from pretrain_bert_ict import get_batch


def setup():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
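    # start the JVM for PyLucene (headless, so no display is required)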
    lucene.initVM(vmargs=['-Djava.awt.headless=true'])


def run():
    dset = get_ict_dataset(use_titles=False, query_in_block_prob=0.1)
    dataloader = iter(get_one_epoch_dataloader(dset))

    index_dir = SimpleFSDirectory(Paths.get("index/"))
    analyzer = StandardAnalyzer()
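    # allow individual tokens of up to 1024 characters (above the StandardAnalyzer default)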
    analyzer.setMaxTokenLength(1024)

    # field for document ID
    t1 = FieldType()
    t1.setStored(True)
    t1.setTokenized(False)

    # field for document text
    t2 = FieldType()
    t2.setStored(True)
    t2.setTokenized(True)
    t2.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)

    correct = total = 0
    round_correct = torch.zeros(1).cuda()
    round_total = torch.zeros(1).cuda()
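    # Each of up to 100 rounds: pull a batch, build a fresh BM25 index over the
    # batch's blocks, and check whether each query retrieves its own block.
    # round_correct / round_total buffer per-round counts for the all_reduce below.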
    for _ in range(100):
        with torch.no_grad():
            try:
                query_tokens, query_pad_mask, \
                block_tokens, block_pad_mask, block_index_data = get_batch(dataloader)
            except Exception:
                # stop early if the dataloader is exhausted or a batch cannot be fetched
                break

        query_tokens = query_tokens.detach().cpu().numpy()
        block_tokens = block_tokens.detach().cpu().numpy()

        query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True) for i in range(query_tokens.shape[0])]
        block_strs = [dset.decode_tokens(block_tokens[i].tolist(), hardcore=True) for i in range(block_tokens.shape[0])]

        # create index writer

        config = IndexWriterConfig(analyzer)
        config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)

        writer = IndexWriter(index_dir, config)

        def add_document(text, writer, doc_id):
            doc = Document()
            doc.add(Field("text", text, t2))
            doc.add(Field("doc_id", doc_id, t1))
            writer.addDocument(doc)

        # add documents to index writer
        for i in range(len(block_strs)):
            add_document(block_strs[i], writer, i)

        # write and finalize the index
        writer.commit()
        writer.close()

        # define BM25 searcher
        searcher = IndexSearcher(DirectoryReader.open(index_dir))
        searcher.setSimilarity(BM25Similarity())

        # run each query against the index and keep its top-8 BM25 hits
        hits_list = []
        for s in query_strs:
            query = QueryParser("text", analyzer).parse(s)
            hits = searcher.search(query, 8).scoreDocs
            hits_list.append(hits)

        # a query counts as correct if its own block is among its top-8 BM25 hits
        for (i, hits) in enumerate(hits_list):
            doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
            round_correct += int(i in doc_ids)
            round_total += 1

        # sum this round's counts across data-parallel workers, fold them into
        # the running totals, and zero the per-round buffers for the next round
        dist.all_reduce(round_correct)
        dist.all_reduce(round_total)
        correct += int(round_correct.item())
        total += int(round_total.item())
        round_correct.zero_()
        round_total.zero_()

        print("Correct: {:8d}   |   Total: {:8d}   |   Fraction: {:6.5f}".format(correct, total, correct / total))

    # Plan
    # Overall accuracy test:
    # have an index with all blocks. For BERT these are token ids; for BM25 these are tokens.
    #
    # 1. Run a batch-size-4096 BM25 self-similarity test. For this I can just detokenize out of
    #    the dataset, get the retrieval scores in forward_step, and log the results.
    # 2. Create a BM25 index over all of Wikipedia and have it ready for use in Megatron QA
    #    (a sketch of this step follows below).
    #
    # Create an index of the block embeddings keyed by block id.
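
# A minimal sketch (not used by run() above) of step 2 in the plan: building a
# persistent BM25 index over a full block collection such as Wikipedia. It assumes
# `wiki_blocks` is an iterable of (doc_id, text) pairs -- that iterable is
# hypothetical and not provided by this repo -- and it reuses the same PyLucene
# calls as run(). lucene.initVM() must have been called (e.g. via setup()) first.
def build_full_bm25_index(wiki_blocks, index_path="wiki_index/"):
    index_dir = SimpleFSDirectory(Paths.get(index_path))
    analyzer = StandardAnalyzer()

    # stored, untokenized id field and stored, tokenized text field (mirrors t1/t2 in run())
    id_field = FieldType()
    id_field.setStored(True)
    id_field.setTokenized(False)

    text_field = FieldType()
    text_field.setStored(True)
    text_field.setTokenized(True)
    text_field.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)

    config = IndexWriterConfig(analyzer)
    config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
    writer = IndexWriter(index_dir, config)

    # add one Lucene document per block, keyed by its block id
    for doc_id, text in wiki_blocks:
        doc = Document()
        doc.add(Field("text", text, text_field))
        doc.add(Field("doc_id", str(doc_id), id_field))
        writer.addDocument(doc)

    # write and finalize the index so an IndexSearcher can open it later
    writer.commit()
    writer.close()
    return index_dir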

if __name__ == "__main__":
    setup()
    run()