# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import itertools
import os
import pickle
import shutil

import numpy as np
import torch

from megatron.training import get_args
from megatron.core import mpu


def detach(tensor):
    """Detach a tensor from the computation graph and return it as a numpy array."""
    return tensor.detach().cpu().numpy()


class OpenRetreivalDataStore(object):
    """
    Serializable data structure for holding data for blocks --
    embeddings and necessary metadata for Retriever
    """
    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
        self.embed_data = dict()
        if embedding_path is None:
            args = get_args()
            embedding_path = args.embedding_path
            rank = args.rank
        self.embedding_path = embedding_path
        self.rank = rank

        if load_from_path:
            self.load_from_file()

        block_data_name = os.path.splitext(self.embedding_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

    def state(self):
        """Return a serializable view of the members that need persisting"""
        return {
            'embed_data': self.embed_data,
        }

    def clear(self):
        """
        Clear the embedding data structures to save memory.
        The metadata ends up getting used, and is also much smaller in
        dimensionality so it isn't really worth clearing.
        """
        self.embed_data = dict()

    def load_from_file(self):
        """Populate members from instance saved to file"""

        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Unpickling BlockData", flush=True)
        with open(self.embedding_path, 'rb') as state_file:
            state_dict = pickle.load(state_file)
        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Finished unpickling BlockData\n", flush=True)

        self.embed_data = state_dict['embed_data']

    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
        """
        Add data for set of blocks
        :param row_id: 1D array of unique int ids for the blocks
        :param block_embeds: 2D array of embeddings of the blocks
            In the case of retriever this will be [start_idx, end_idx, doc_idx]
        """
        for idx, embed in zip(row_id, block_embeds):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            self.embed_data[idx] = np.float16(embed)

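    # Example usage (a minimal sketch; the path, ids, and 128-dim embeddings
    # below are hypothetical):
    #
    #   store = OpenRetreivalDataStore('embeds.pkl', load_from_path=False, rank=0)
    #   store.add_block_data(np.arange(2), np.random.rand(2, 128))
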
    def save_shard(self):
        """
        Save the block data that was created this in this process
        """
        os.makedirs(self.temp_dir_name, exist_ok=True)

        # save the data for each shard
        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as writer:
            pickle.dump(self.state(), writer)

    def merge_shards_and_save(self):
        """Combine all the shards made using save_shard"""
        shard_names = os.listdir(self.temp_dir_name)
        seen_own_shard = False

        for fname in shard_names:
            shard_rank = int(os.path.splitext(fname)[0])
            if shard_rank == self.rank:
                seen_own_shard = True
                continue

            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)
                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])

                # add the shard's data and check to make sure there
                # is no overlap
                self.embed_data.update(data['embed_data'])
                assert len(self.embed_data) == old_size + shard_size

        assert seen_own_shard

        # save the consolidated shards and remove temporary directory
        with open(self.embedding_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)

        print("Finished merging {} shards for a total of {} embeds".format(
            len(shard_names), len(self.embed_data)), flush=True)

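# Typical multi-rank flow (a sketch; assumes torch.distributed is initialized
# and that every rank embeds a disjoint split of the blocks):
#
#   store = OpenRetreivalDataStore(load_from_path=False)
#   store.add_block_data(row_ids, block_embeds)  # per-rank shard of the data
#   store.save_shard()                           # writes <path>_tmp/<rank>.pkl
#   torch.distributed.barrier()                  # wait until all shards exist
#   if torch.distributed.get_rank() == 0:
#       store.merge_shards_and_save()            # consolidates to embedding_path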

class FaissMIPSIndex(object):
    """
    Wrapper object for a BlockData which similarity search via FAISS under the hood
    """
    def __init__(self, embed_size, embed_data=None, use_gpu=False):
        self.embed_size = embed_size
        self.embed_data = embed_data
        self.use_gpu = use_gpu

        self.mips_index = None
        self._set_mips_index()

    def _set_mips_index(self):
        """
        Create a Faiss Flat index with inner product as the metric
        to search against
        """
        try:
            import faiss
        except ImportError:
            raise ImportError("Please install faiss to use FaissMIPSIndex")

        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Building index", flush=True)

        cpu_index = faiss.IndexFlatIP(self.embed_size)

        if self.use_gpu:
            # create resources and config for GpuIndex
            config = faiss.GpuMultipleClonerOptions()
            config.shard = True
            config.useFloat16 = True
            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
            self.mips_index = faiss.IndexIDMap(gpu_index)
            if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on GPU", flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.mips_index = faiss.IndexIDMap(cpu_index)
            if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on CPU", flush=True)

        # if we were constructed with a BlockData, then automatically load it
        # when the FAISS structure is built
        if self.embed_data is not None:
            self.add_embed_data(self.embed_data)

    def reset_index(self):
        """Delete existing index and create a new"""
        del self.mips_index

        # reset the block data so that _set_mips_index will reload it as well
        if self.embed_data is not None:
            embed_data_path = self.embed_data.embedding_path
            del self.embed_data
            self.embed_data = OpenRetreivalDataStore(embed_data_path)

        self._set_mips_index()

    def update_index(self):
        """Delete existing index and create a new"""
        del self.mips_index

        # reset the block data so that _set_mips_index will reload it as well
        if self.embed_data is not None:
            self.embed_data.load_from_file()
        self._set_mips_index()

    def add_embed_data(self, all_embed_data):
        """Add the embedding of each block to the underlying FAISS index"""

        # this assumes the embed_data is a dict : {int: np.array<float>}
        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

        # the embeddings have to be entered in as float32 even though the math
        # internally is done with float16.
        embeds_arr = np.float32(np.array(block_embeds))
        indices_arr = np.array(block_indices)

        # we no longer need the embedding data since it's in the index now
        all_embed_data.clear()

        self.mips_index.add_with_ids(embeds_arr, indices_arr)

        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
            print(">>> Finished adding block data to index", flush=True)

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """
        Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return the distances, indices, and a
                                [num_queries x k x embed_dim] array of the
                                reconstructed blocks
                            if False: return [num_queries x k] array of
                                distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        if reconstruct:
            # get the vectors themselves
            top_k_block_embeds = self.mips_index.search_and_reconstruct(
                query_embeds, top_k)
            return top_k_block_embeds
        else:
            # get distances and indices of closest vectors
            distances, block_indices = self.mips_index.search(query_embeds, top_k)
            return distances, block_indices
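

# Example end-to-end query (a sketch; the paths, sizes, and top_k value are
# hypothetical):
#
#   store = OpenRetreivalDataStore('embeds.pkl')
#   index = FaissMIPSIndex(embed_size=128, embed_data=store, use_gpu=False)
#   query_embeds = torch.rand(4, 128)  # four 128-dim query embeddings
#   distances, block_indices = index.search_mips_index(
#       query_embeds, top_k=5, reconstruct=False)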