faiss_base.py 3.94 KB
Newer Older
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
1
2
3
4
5
6
7
8
9
10
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

"""
This class implements a simple, un-optimized wrapper around a Faiss index, that
implements the Index interface (see ..index.py). While this class is
instantiable, it is meant to be extended with optimizations in classes that
inherit from this class (see FaissParAddIndex, for an example).
"""

from datetime import timedelta
liangjing's avatar
v1  
liangjing committed
11
import numpy as np
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
12
13
14
15
16
17
18
import os
import torch
from tqdm import tqdm

from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.retro.external_libs import faiss
liangjing's avatar
v1  
liangjing committed
19
20
21
22
23
from tools.retro.index.index import Index
from tools.retro.index.utils import (
    get_training_data_merged_path,
    num_samples_to_block_ranges,
)
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
24
25
26
27


class FaissBaseIndex(Index):
    '''Simple, un-optimized Faiss index wrapper.

    Both training and adding run on rank 0 only; every rank then waits on a
    torch.distributed barrier. Subclasses may override these methods to add
    optimizations.
    '''

    # OpenMP thread count used by Faiss during training and adding. It is set
    # explicitly before heavy Faiss work because torch.distributed resets it
    # to 1. Subclasses may override this class attribute.
    omp_num_threads = 64

    def _train(self):
        '''Train index (rank 0's method).

        Loads the merged training embeddings as a read-only float32 memmap,
        builds an index from ``args.retro_index_str``, and trains it using a
        GPU flat index for the IVF clustering step. Writes the trained (still
        empty) index to ``self.get_empty_index_path()``. Returns immediately
        if that file already exists.
        '''

        args = get_retro_args()

        assert torch.distributed.get_rank() == 0

        # Set num threads (torch.distributed reset it to 1).
        faiss.omp_set_num_threads(self.omp_num_threads)

        empty_index_path = self.get_empty_index_path()

        # Index already exists? -> return.
        if os.path.isfile(empty_index_path):
            return

        # Load data: read-only memmap of float32 rows.
        # NOTE(review): rows are reshaped by args.hidden_size while the index
        # below is built with dimension args.retro_index_nfeats; these are
        # presumably equal (the BERT embedding size) -- confirm.
        merged_path = get_training_data_merged_path()
        inp = np.memmap(
            merged_path,
            dtype="f4",
            mode="r",
        ).reshape((-1, args.hidden_size))

        # Init index.
        index = faiss.index_factory(args.retro_index_nfeats,
                                    args.retro_index_str)

        # Move the clustering step to GPU. Only the flat index used for the
        # IVF k-means clustering is placed on the GPUs; the index itself stays
        # on CPU. extract_index_ivf presumably requires retro_index_str to
        # describe an IVF index -- confirm for non-IVF configurations.
        print_rank_0("> move faiss index to gpu.")
        index_ivf = faiss.extract_index_ivf(index)
        clustering_index = \
            faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
        index_ivf.clustering_index = clustering_index
        print_rank_0("> finished moving to gpu.")

        # Enable verbose output on each component involved in training.
        # (c_verbose is not defined in this class; presumably inherited.)
        self.c_verbose(index, True)
        self.c_verbose(index_ivf, True)
        self.c_verbose(index_ivf.quantizer, True)
        self.c_verbose(index_ivf.clustering_index, True)

        # Train index.
        index.train(inp)

        # Save index.
        faiss.write_index(index, empty_index_path)

    def train(self):
        '''Train index.

        Rank 0 performs the training; all ranks then synchronize on a
        barrier so that no rank proceeds before the index file is written.
        '''

        # Single process only.
        if torch.distributed.get_rank() == 0:
            self._train()

        torch.distributed.barrier()

    def _add(self, text_dataset):
        '''Add to index (rank 0's method).

        Reads the trained (empty) index from disk, embeds ``text_dataset``
        block by block with a BertEmbedder, adds each block's embeddings, and
        writes the result to ``self.get_added_index_path()``. Returns
        immediately if that file already exists.

        Args:
            text_dataset: dataset to embed; only ``len()`` is used here
                directly, the rest is consumed by ``embed_text_dataset_block``.
        '''

        assert torch.distributed.get_rank() == 0

        args = get_retro_args()

        # Set num threads (torch.distributed reset it to 1).
        faiss.omp_set_num_threads(self.omp_num_threads)

        # Empty/added index paths.
        empty_index_path = self.get_empty_index_path()
        added_index_path = self.get_added_index_path()

        # Skip adding, if index exists. Checked before computing block ranges
        # and constructing the embedder so the early exit does no extra work.
        if os.path.isfile(added_index_path):
            return

        dataset_sample_ranges = num_samples_to_block_ranges(len(text_dataset))

        # Bert embedder.
        embedder = BertEmbedder(args.retro_bert_batch_size,
                                args.retro_bert_max_chunk_length,
                                args.bert_embedder_type)

        # Read trained index.
        index = faiss.read_index(empty_index_path)

        # Iterate data blocks & add.
        for sample_range in tqdm(dataset_sample_ranges, "faiss_base.add"):

            # Embed text.
            embeds = self.embed_text_dataset_block(
                embedder, text_dataset, sample_range)

            # Add to index.
            index.add(embeds)

        # Write index.
        faiss.write_index(index, added_index_path)

    def add(self, text_dataset):
        '''Add to index.

        Rank 0 performs the add; all ranks then synchronize on a barrier.

        Returns:
            Path of the added (populated) index file.
        '''

        # Single process only.
        if torch.distributed.get_rank() == 0:
            self._add(text_dataset)

        # Wait for rank 0.
        torch.distributed.barrier()

        # Get output index path, for return.
        return self.get_added_index_path()