import os
import pickle
import shutil

import numpy as np
import torch

from megatron import get_args
from megatron import mpu


def detach(tensor):
    return tensor.detach().cpu().numpy()


class OpenRetreivalDataStore(object):
    """
    Serializable data structure for holding data for blocks --
    embeddings and necessary metadata for Retriever
    """
    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
        self.embed_data = dict()
        if embedding_path is None:
            args = get_args()
            embedding_path = args.embedding_path
            rank = args.rank
        self.embedding_path = embedding_path
        self.rank = rank

        if load_from_path:
            self.load_from_file()

        block_data_name = os.path.splitext(self.embedding_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

    def state(self):
        return {
            'embed_data': self.embed_data,
        }

    def clear(self):
        """
        Clear the embedding data structures to save memory.
        The metadata ends up getting used, and is also much smaller in
        dimensionality so it isn't really worth clearing.
        """
        self.embed_data = dict()

    def load_from_file(self):
        """Populate members from instance saved to file"""

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Unpickling BlockData", flush=True)
        with open(self.embedding_path, 'rb') as f:
            state_dict = pickle.load(f)
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Finished unpickling BlockData\n", flush=True)

        self.embed_data = state_dict['embed_data']

    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
        """
        Add data for set of blocks
        :param row_id: 1D array of unique int ids for the blocks
        :param block_embeds: 2D array of embeddings of the blocks,
            one row per id in row_id
        """
        for idx, embed in zip(row_id, block_embeds):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            self.embed_data[idx] = np.float16(embed)

    def save_shard(self):
        """
        Save the block data that was created this in this process
        """
        os.makedirs(self.temp_dir_name, exist_ok=True)

        # save the data for each shard
        shard_path = '{}/{}.pkl'.format(self.temp_dir_name, self.rank)
        with open(shard_path, 'wb') as writer:
            pickle.dump(self.state(), writer)

    def merge_shards_and_save(self):
        """Combine all the shards made using save_shard"""
        shard_names = os.listdir(self.temp_dir_name)
        seen_own_shard = False

        for fname in shard_names:
            shard_rank = int(os.path.splitext(fname)[0])
            if shard_rank == self.rank:
                seen_own_shard = True
                continue

            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)
                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])

                # add the shard's data and check to make sure there
                # is no overlap
                self.embed_data.update(data['embed_data'])
                assert len(self.embed_data) == old_size + shard_size

        assert seen_own_shard

        # save the consolidated shards and remove temporary directory
        with open(self.embedding_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)

        print("Finished merging {} shards for a total of {} embeds".format(
            len(shard_names), len(self.embed_data)), flush=True)
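
# Usage sketch (illustrative, not part of the original file): each
# data-parallel rank embeds its own slice of blocks, writes a shard, and
# rank 0 merges the shards. The ids, shapes, and the explicit barrier below
# are assumptions for the example.
#
#   store = OpenRetreivalDataStore(embedding_path='blocks.pkl',
#                                  load_from_path=False, rank=rank)
#   row_ids = np.arange(rank * 1000, (rank + 1) * 1000)  # unique per rank
#   embeds = np.random.rand(1000, 128)                   # [n_blocks x dim]
#   store.add_block_data(row_ids, embeds)
#   store.save_shard()
#   torch.distributed.barrier()  # all ranks must finish writing shards
#   if rank == 0:
#       store.merge_shards_and_save()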


class FaissMIPSIndex(object):
    """
    Wrapper object for a BlockData which similarity search via FAISS under the hood
    """
    def __init__(self, embed_size, embed_data=None, use_gpu=False):
        self.embed_size = embed_size
        self.embed_data = embed_data
        self.use_gpu = use_gpu

        self.mips_index = None
        self._set_mips_index()

    def _set_mips_index(self):
        """
        Create a Faiss Flat index with inner product as the metric
        to search against
        """
        try:
            import faiss
        except ImportError:
            raise ImportError("Please install faiss to use FaissMIPSIndex")

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Building index", flush=True)

        cpu_index = faiss.IndexFlatIP(self.embed_size)

        if self.use_gpu:
            # create resources and config for GpuIndex
            config = faiss.GpuMultipleClonerOptions()
            config.shard = True
            config.useFloat16 = True
            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
            self.mips_index = faiss.IndexIDMap(gpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on GPU", flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.mips_index = faiss.IndexIDMap(cpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on CPU", flush=True)

        # if we were constructed with a BlockData, then automatically load it
        # when the FAISS structure is built
        if self.embed_data is not None:
            self.add_embed_data(self.embed_data)

    def reset_index(self):
        """Delete existing index and create a new"""
        del self.mips_index

        # reset the block data so that _set_mips_index will reload it as well
        if self.embed_data is not None:
            embed_data_path = self.embed_data.embedding_path
            del self.embed_data
            self.embed_data = OpenRetreivalDataStore(embed_data_path)

        self._set_mips_index()

    def update_index(self):
        """Delete existing index and create a new"""
        del self.mips_index

        # reset the block data so that _set_mips_index will reload it as well
        if self.embed_data is not None:
            self.embed_data.load_from_file()
        self._set_mips_index()

    def add_embed_data(self, all_embed_data):
        """Add the embedding of each block to the underlying FAISS index"""

        # this assumes the embed_data is a dict : {int: np.array<float>}
        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

        # the embeddings have to be entered in as float32 even though the math
        # internally is done with float16.
        embeds_arr = np.float32(np.array(block_embeds))
        indices_arr = np.array(block_indices)

        # we no longer need the embedding data since it's in the index now
        all_embed_data.clear()

        self.mips_index.add_with_ids(embeds_arr, indices_arr)

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">>> Finished adding block data to index", flush=True)

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """
        Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return the (distances, indices,
                                embeddings) output of search_and_reconstruct,
                                where embeddings is a
                                [num_queries x k x embed_dim] array of blocks
                            if False: return [num_queries x k] array of
                                distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))
        if reconstruct:
            # get the vectors themselves
            top_k_block_embeds = self.mips_index.search_and_reconstruct(
                query_embeds, top_k)
            return top_k_block_embeds
        else:
            # get distances and indices of closest vectors
            distances, block_indices = self.mips_index.search(query_embeds, top_k)
            return distances, block_indices
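
# Usage sketch (illustrative, not part of the original file): load the merged
# embedding store, build the FAISS inner-product index over it, and query.
# The path, embedding size, and query shapes are assumptions for the example.
#
#   embed_data = OpenRetreivalDataStore('blocks.pkl', load_from_path=True)
#   index = FaissMIPSIndex(embed_size=128, embed_data=embed_data,
#                          use_gpu=False)
#   queries = torch.rand(4, 128)  # [num_queries x embed_dim]
#   distances, block_indices = index.search_mips_index(queries, top_k=5,
#                                                      reconstruct=False)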