from collections import defaultdict
import os
import pickle
import shutil

import faiss
import numpy as np
import torch

from megatron import get_args


def detach(tensor):
    """Detach a torch tensor and convert it to a numpy array."""
    return tensor.detach().cpu().numpy()


class BlockData(object):
    """Serializable store that maps block indices to block embeddings and metadata."""

    def __init__(self):
        self.embed_data = dict()
        self.meta_data = dict()
        self.temp_dir_name = 'temp_block_data'

    def state(self):
        return {
            'embed_data': self.embed_data,
            'meta_data': self.meta_data
        }

    def clear(self):
        """Clear the data structures to save memory"""
        self.embed_data = dict()
        self.meta_data = dict()

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block data")
        state_dict = pickle.load(open(fname, 'rb'))
        print(" > Finished unpickling")

        new_index = cls()
        new_index.embed_data = state_dict['embed_data']
        new_index.meta_data = state_dict['meta_data']
        return new_index

    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            self.embed_data[idx] = np.float16(embed)
            self.meta_data[idx] = meta

    def save_shard(self, rank):
        if not os.path.isdir(self.temp_dir_name):
            os.mkdir(self.temp_dir_name)

        # save the data for each shard
        with open('{}/{}.pkl'.format(self.temp_dir_name, rank), 'wb') as data_file:
            pickle.dump(self.state(), data_file)

    def consolidate_shards_and_save(self, ignore_shard=0):
        """Combine all the shards made using self.save_shard()"""
        fnames = os.listdir(self.temp_dir_name)
        for fname in fnames:
            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)

                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])
                self.embed_data.update(data['embed_data'])
                self.meta_data.update(data['meta_data'])
                assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)

        args = get_args()
        with open(args.block_data_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)
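
# Illustrative end-to-end sketch for BlockData (hypothetical names: `rank`, `block_indices`,
# `block_embeds`, and `block_metas` are not defined in this module):
#
#   block_data = BlockData()
#   block_data.add_block_data(block_indices, detach(block_embeds), block_metas)
#   block_data.save_shard(rank)
#   # ... once every rank has written its shard (e.g. after a barrier) ...
#   if rank == 0:
#       block_data.consolidate_shards_and_save()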


class FaissMIPSIndex(object):
    """FAISS-backed index for maximum inner product search (MIPS) over block embeddings."""

    def __init__(self, index_type, embed_size, **index_kwargs):
        self.index_type = index_type
        self.embed_size = embed_size
        self.index_kwargs = dict(index_kwargs)

        # ALSH (asymmetric LSH) parameters, used to reduce MIPS to L2 search
        # when the index type is 'flat_l2'
        self.m = 5
        self.u = 0.99
        self.max_norm = None
        self.block_mips_index = self.get_block_index()

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block index data")
        state_dict = pickle.load(open(fname, 'rb'))
        print(" > Finished unpickling")
        index_type = state_dict['index_type']
        index_kwargs = state_dict['index_kwargs']
        embed_size = state_dict['embed_size']

        new_index = cls(index_type, embed_size, **index_kwargs)

        return new_index
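
    # Note: load_from_file only rebuilds the index from its pickled configuration
    # (index_type, embed_size, index_kwargs); the block embeddings themselves are not
    # restored here and need to be re-added with add_block_embed_data.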

    def get_block_index(self):
        INDEX_TYPES = ['flat_l2', 'flat_ip']
        if self.index_type not in INDEX_TYPES:
            raise ValueError("Invalid index type specified")

        if self.index_type == 'flat_l2':
            index = faiss.IndexFlatL2(self.embed_size + 2 * self.m)
            return faiss.IndexIDMap(index)
        elif self.index_type == 'flat_ip':
            index = faiss.IndexFlatIP(self.embed_size)
            return faiss.IndexIDMap(index)

    def add_block_embed_data(self, all_block_data, clear_block_data=False):
        """Add the embedding of each block to the underlying FAISS index"""
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
        if clear_block_data:
            all_block_data.clear()

        if self.index_type == 'flat_l2':
            block_embeds = self.alsh_block_preprocess_fn(block_embeds)
        self.block_mips_index.add_with_ids(np.array(block_embeds), np.array(block_indices))

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.

        :param reconstruct: if True, return the output of FAISS search_and_reconstruct,
                            which includes the reconstructed [num_queries x k x embed_dim] block embeddings;
                            if False, return a [num_queries x k] array of distances and another of block indices
        """
        if self.index_type == 'flat_l2':
            query_embeds = self.alsh_query_preprocess_fn(query_embeds)

        if reconstruct:
            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
            return top_k_block_embeds
        else:
            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
            return distances, block_indices
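
    # Illustrative query sketch (hypothetical names: `mips_index`, `query_tensor`, and
    # `top_k` are not defined in this file):
    #
    #   query_embeds = np.float32(detach(query_tensor))
    #   distances, block_indices = mips_index.search_mips_index(query_embeds, top_k, reconstruct=False)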

    def get_norm_powers_and_halves_array(self, embeds):
        norm = np.linalg.norm(embeds, axis=1)
        norm_powers = [np.multiply(norm, norm)]  # squared L2 norms of all
        for i in range(self.m - 1):
            norm_powers.append(np.multiply(norm_powers[-1], norm_powers[-1]))
        # [num_blocks x self.m]
        norm_powers = np.transpose(np.array(norm_powers))
        halves_array = 0.5 * np.ones(norm_powers.shape)

        return norm_powers, halves_array

    def alsh_block_preprocess_fn(self, block_embeds):
        block_embeds = np.array(block_embeds)
        if self.max_norm is None:
            self.max_norm = max(np.linalg.norm(block_embeds, axis=1))
        if self.max_norm > 1:
            block_embeds = self.u / self.max_norm * block_embeds
        norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)

        # P'(S(x)) for all x in block_embeds
        return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))

    def alsh_query_preprocess_fn(self, query_embeds):
        max_norm = max(np.linalg.norm(query_embeds, axis=1))
        if max_norm > 1:
            query_embeds = self.u / max_norm * query_embeds
        norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)

        # Q'(S(x)) for all x in query_embeds
        return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
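
    # Background on the ALSH preprocessing above (a summary of the construction these two
    # functions implement): after rescaling by u / max_norm (when the max norm exceeds 1),
    # a block x becomes P(x) = [x, ||x||^2, ||x||^4, ..., ||x||^(2^m), 1/2, ..., 1/2] and a
    # query q becomes Q(q) = [q, 1/2, ..., 1/2, ||q||^2, ..., ||q||^(2^m)]. Expanding the
    # squared distance and telescoping the norm powers gives
    #   ||P(x) - Q(q)||^2 = m/2 + ||x||^(2^(m+1)) + ||q||^(2^(m+1)) - 2 * (x . q),
    # and since the rescaled norms are below u < 1, the two high-power terms are nearly
    # zero, so the nearest L2 neighbor in the augmented space is approximately the block
    # with the maximum inner product. This is why 'flat_l2' can serve as a MIPS index.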


class RandProjectionLSHIndex(object):
    """Random projection LSH index that buckets block embeddings by their hash"""
    def __init__(self, embed_size, num_buckets, whiten=True, seed=0):
        np.random.seed(seed)
        self.hash_data = defaultdict(list)
        hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
        self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
        self.embed_mean = None
        self.embed_whitener = None
        self.whiten = whiten

    def state(self):
        state = {
            'hash_data': self.hash_data,
            'hash_matrix': self.hash_matrix,
            'embed_mean': self.embed_mean,
            'embed_whitener': self.embed_whitener,
        }
        return state

    def save_to_file(self):
        args = get_args()
        with open(args.block_index_path, 'wb') as index_file:
            pickle.dump(self.state(), index_file)

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block hash data")
        state_dict = pickle.load(open(fname, 'rb'))
        print(" > Finished unpickling")
        hash_matrix = state_dict['hash_matrix']

        new_index = cls(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
        new_index.hash_data = state_dict['hash_data']
        new_index.embed_mean = state_dict.get('embed_mean')
        new_index.embed_whitener = state_dict.get('embed_whitener')
        new_index.hash_matrix = hash_matrix

        return new_index

    def get_block_bucket(self, hash):
        return self.hash_data[hash]

    def hash_embeds(self, embeds, write_block_data=None):
        """Hash a tensor of embeddings using a random projection matrix"""
        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix).type(embeds.dtype))
        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), dim=1)
        embed_hashes = detach(torch.argmax(embed_scores, dim=1))

        if write_block_data is not None:
            for hash, indices in zip(embed_hashes, write_block_data):
                self.hash_data[hash].append(indices)

        return embed_hashes
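
    # How the bucketing works: embeddings are projected onto num_buckets / 2 random unit
    # directions (the columns of hash_matrix); concatenating the scores with their
    # negations yields num_buckets signed scores, and argmax assigns each embedding to the
    # direction (positive or negative) it is most aligned with. When write_block_data is
    # provided, each entry is appended to the list for its bucket in hash_data.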

    def hash_whitened_block_embeds(self, block_data):
        """Transform all block embeds to have zero mean and unit covariance
        when treated as samples from a distribution"""
        block_idx, all_embeds = zip(*block_data.embed_data.items())
        arr_embeds = np.transpose(np.array(all_embeds))

        mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
        centered = arr_embeds - mean
        inv_cov = np.linalg.inv(np.cov(arr_embeds))
        whitener = np.transpose(np.linalg.cholesky(inv_cov))
        whitened = np.float16(np.transpose(whitener.dot(centered)))

        self.embed_mean = mean.reshape(-1)
        self.embed_whitener = whitener
        self.hash_data = defaultdict(list)
        batch_size = 16384
        i = 0

        args = get_args()
        with torch.no_grad():
            while True:
                if args.debug:
                    print(i, flush=True)
                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
                batch_meta = [block_data.meta_data[idx] for idx in block_idx[batch_slice]]
                if len(batch_meta) == 0:
                    break

                self.hash_embeds(batch_embed, batch_meta)
                i += 1
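
    # The whitener above is the transposed Cholesky factor of the inverse covariance, so
    # the transformed block embeddings have approximately zero mean and identity
    # covariance. exact_mips_test relies on this: with whiten=True it draws its random
    # queries from a standard normal, roughly matching the whitened block distribution.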

    def exact_mips_equals(self, query_embeds, all_block_data, norm_blocks):
        """For each query, determine whether the mips block is in the correct hash bucket"""
        shuffled_block_idx, block_embeds = zip(*all_block_data.items())
        if norm_blocks:
            block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)
        with torch.no_grad():
            # hash_embeds expects torch tensors, so move the numpy arrays onto the GPU first
            query_hashes = self.hash_embeds(torch.cuda.FloatTensor(query_embeds))

            # [num_query x num_blocks]
            inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
                                          torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
            max_inner_product_idxes = detach(torch.argmax(inner_products, dim=1))
            best_blocks = np.array([all_block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes])
            best_block_hashes = self.hash_embeds(torch.cuda.FloatTensor(best_blocks))

            print('Query hashes: ', query_hashes)
            print('Block hashes: ', best_block_hashes)
            equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)

            # array of zeros and ones which can be used for counting success
            return equal_arr

    def exact_mips_test(self, num_queries, all_block_data, norm_blocks):
        if self.whiten:
            if self.embed_mean is None:
                self.hash_whitened_block_embeds(all_block_data)
            embed_size = self.hash_matrix.shape[0]
            query_embeds = np.random.multivariate_normal(np.zeros(embed_size), np.eye(embed_size), num_queries)
            query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
        else:
            block_idx, all_embeds = zip(*all_block_data.items())
            arr_embeds = np.transpose(np.array(all_embeds))

            mean = np.mean(arr_embeds, axis=1)
            cov = np.cov(arr_embeds)
            query_embeds = np.random.multivariate_normal(mean, cov, num_queries)

        equal_arr = self.exact_mips_equals(query_embeds, all_block_data, norm_blocks)
        print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
        print(equal_arr)