evaluate_utils.py 7.13 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Mostofa Patwary's avatar
Mostofa Patwary committed
2
3
4

import torch

xingjinliang's avatar
xingjinliang committed
5
6
7
8
9
from megatron.training import get_args, print_rank_0
from megatron.training.checkpointing import load_biencoder_checkpoint
from megatron.legacy.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.legacy.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.legacy.model.biencoder_model import get_model_provider
Mostofa Patwary's avatar
Mostofa Patwary committed
10
from megatron.training import get_model
Mostofa Patwary's avatar
Mostofa Patwary committed
11
12
13
14
15
from tasks.orqa.unsupervised.nq import get_nq_dataset
from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
from tasks.orqa.unsupervised.nq import process_nq_batch
from tasks.orqa.unsupervised.qa_utils import calculate_matches

Mostofa Patwary's avatar
Mostofa Patwary committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

class ORQAEvaluator(object):
    """Retrieval evaluator for a biencoder query encoder.

    Construction loads the evidence dataset, builds the model and restores
    its biencoder checkpoint, and creates the FAISS MIPS index on local
    rank 0 of each node. ``evaluate`` then embeds the questions of a QA
    split, searches the index, and prints top-k hit statistics.
    """

    def __init__(self):
        args = get_args()
        self.embedding_size = args.hidden_size
        self.faiss_use_gpu = args.faiss_use_gpu
        self.evidence_embedder_obj = None
        self.evidence_dataset = None
        self.mips_index = None
        self.eval_dataset = None

        # Get Evidence (Wikipedia) dataset
        self.get_evidence_dataset()

        # Load query encoder checkpoint. Only the query model is requested
        # unless the query and context encoders share one model.
        shared_model = args.biencoder_shared_query_context_model
        only_query_model = not shared_model

        model_provider = get_model_provider(
            only_query_model=only_query_model,
            biencoder_shared_query_context_model=shared_model)
        model = get_model(model_provider)

        self.model = load_biencoder_checkpoint(
            model, only_query_model=only_query_model)

        assert len(self.model) == 1
        self.model[0].eval()

        # Load faiss indexer
        self.faiss_wrapper()

    def get_evidence_embedding(self):
        """Create the evidence embedding store, loading it from its path."""
        self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)

    def get_evidence_dataset(self):
        """Load the open-retrieval evidence dataset into ``evidence_dataset``."""
        self.evidence_dataset = get_open_retrieval_wiki_dataset()

    def faiss_wrapper(self):
        """Build the FAISS MIPS index on local rank 0, then synchronize.

        Other local ranks keep ``self.mips_index`` as ``None``; every rank
        waits on the barrier until index construction has finished.
        """
        # Initialize the FAISS wrapper on local rank 0 only, as the evidence
        # embeddings are distributed over all the GPUs in a node and FAISS
        # is not thread-safe.
        args = get_args()
        if args.local_rank == 0:
            # Get evidence embeddings computed using context encoder
            self.get_evidence_embedding()

            assert self.evidence_embedder_obj is not None
            self.mips_index = FaissMIPSIndex(
                embed_size=self.embedding_size,
                embed_data=self.evidence_embedder_obj,
                use_gpu=self.faiss_use_gpu)

        # Wait for the FAISS index to be initialized in all the nodes
        torch.distributed.barrier()

    def generate_query_vectors(self, qa_data, split):
        """Embed every question of a QA split with the query encoder.

        Args:
            qa_data: forwarded to ``get_nq_dataset``.
            split: forwarded to ``get_nq_dataset``.

        Returns:
            A tuple ``(query_tensor, reference_list)``: the per-batch
            embeddings concatenated along dim 0, and the references
            collected from the batches in the same order.
        """
        self.eval_dataset = get_nq_dataset(qa_data, split)
        dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)

        # The model does not change between batches, so peel off wrapper
        # modules once instead of on every iteration.
        assert len(self.model) == 1
        unwrapped_model = self.model[0]
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        query_vectors = []
        reference_list = []

        for batch in dataloader:
            # batch also has query_tokens and query_pad_data
            query_tokens, query_mask, query_types, \
                query_len, reference = process_nq_batch(batch)

            with torch.no_grad():
                query_logits = unwrapped_model.embed_text(
                    unwrapped_model.query_model, query_tokens,
                    query_mask, query_types)

            reference_list.extend(reference)
            query_vectors.extend(query_logits.split(1, dim=0))
            if len(query_vectors) % 100 == 0:
                print_rank_0('Encoded queries {}'.format(len(query_vectors)))

        query_tensor = torch.cat(query_vectors, dim=0)
        print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))

        assert query_tensor.size(0) == len(self.eval_dataset)
        return query_tensor, reference_list

    @staticmethod
    def _node_layout(rank, world_size, device_count):
        """Return ``(num_nodes, node_id)`` for a one-rank-per-device layout.

        Args:
            rank: global rank of the calling process.
            world_size: total number of processes.
            device_count: number of CUDA devices visible on this node.

        Raises:
            RuntimeError: if ``world_size`` is not a positive multiple of
                ``device_count``. In that case the per-node groups built by
                ``evaluate`` would not cover every rank.
        """
        if device_count <= 0 or world_size % device_count != 0:
            raise RuntimeError(
                'world size {} is not a positive multiple of the local '
                'device count {}'.format(world_size, device_count))
        return world_size // device_count, rank // device_count

    def evaluate(self, qa_data, split):
        """Run retrieval for a QA split and print top-k hit statistics.

        Queries are gathered within each node, searched on the node's local
        rank 0 (the only rank holding the index), and the results are
        broadcast back so that every rank scores its own queries.

        Assumes each node runs ``torch.cuda.device_count()`` ranks with
        contiguous global rank ids.

        Args:
            qa_data: forwarded to ``generate_query_vectors``.
            split: forwarded to ``generate_query_vectors``; also used as
                the label of the printed report.

        Raises:
            RuntimeError: if the world size does not match the per-node
                device count, or if local rank 0 has no FAISS index.
        """
        args = get_args()
        query_tensor, reference_list = self.generate_query_vectors(
            qa_data, split)
        num_queries = len(query_tensor)

        local_rank = args.local_rank
        rank = torch.distributed.get_rank()
        device_count = torch.cuda.device_count()
        num_nodes, node_id = self._node_layout(
            rank, torch.distributed.get_world_size(), device_count)

        # Every rank has to take part in the creation of every group, even
        # the groups it does not belong to, hence the loop over all nodes.
        for node in range(num_nodes):
            start_rank = node * device_count
            ranks_list = list(range(start_rank, start_rank + device_count))
            node_group = torch.distributed.new_group(ranks=ranks_list)

            if node == node_id:
                device_start_rank = start_rank
                group = node_group

        # Gather a detached copy of the local queries from all ranks of
        # this node.
        gather_input = query_tensor.detach().clone()
        tensor_list = [torch.empty_like(gather_input)
                       for _ in range(device_count)]
        torch.distributed.all_gather(tensor_list, gather_input, group=group)

        if local_rank == 0:
            if self.mips_index is None:
                raise RuntimeError(
                    'FAISS index is not initialized on local rank 0')
            all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()

            distance, topkindex = self.mips_index.search_mips_index(
                all_query_tensor, top_k=args.faiss_topk_retrievals,
                reconstruct=False)
            distance = torch.from_numpy(distance).cuda()
            topkindex = torch.LongTensor(topkindex).cuda()
        else:
            # Receive buffers for the results broadcast from local rank 0.
            result_shape = (device_count * num_queries,
                            args.faiss_topk_retrievals)
            distance = torch.empty(result_shape, dtype=torch.float32).cuda()
            topkindex = torch.empty(result_shape, dtype=torch.int64).cuda()

        torch.distributed.broadcast(distance, src=device_start_rank,
                                    group=group)
        torch.distributed.broadcast(topkindex, src=device_start_rank,
                                    group=group)

        # Keep only the slice that corresponds to this rank's own queries.
        distance = torch.split(distance, num_queries, dim=0)[local_rank]
        topkindex = torch.split(topkindex, num_queries, dim=0)[local_rank]

        top_ids_and_scores = [
            (topkarray.tolist(), darray.tolist())
            for darray, topkarray in zip(distance, topkindex)
        ]

        passages = self.evidence_dataset.id2text
        match_stats = calculate_matches(passages,
                                        reference_list,
                                        top_ids_and_scores,
                                        workers_num=args.num_workers,
                                        match_type=args.faiss_match)
        top_k_hits = match_stats.top_k_hits

        print_rank_0("{} SET RESULTS".format(split))
        print_rank_0("topk-{} documents hits {}".format(
            args.faiss_topk_retrievals, top_k_hits))
        top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
        print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))

        for i in args.retriever_report_topk_accuracies:
            print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100))