faiss_base.py 3.69 KB
Newer Older
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

"""
This class implements a simple, un-optimized wrapper around a Faiss index that
implements the Index interface (see ..index.py). While this class is
instantiable, it is meant to be extended with optimizations in classes that
inherit from it (see FaissParAddIndex for an example).
"""

from datetime import timedelta
import os
import torch
from tqdm import tqdm

from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.retro.external_libs import faiss
from tools.retro.index import Index
from tools.retro.index.utils import num_samples_to_block_ranges


class FaissBaseIndex(Index):
    '''Simple, un-optimized Faiss index implementing the Index interface.

    Both training and adding run on rank 0 only; every other rank waits at a
    torch.distributed barrier. Each stage is skipped when its output file
    already exists, so an interrupted pipeline can be resumed.
    '''

    # Number of OpenMP threads Faiss uses on rank 0. torch.distributed resets
    # the OpenMP thread count to 1, so it is raised again before Faiss work.
    # Subclasses may override this.
    num_omp_threads = 64

    @staticmethod
    def _write_index(index, path):
        '''Write `index` to `path` atomically.

        `_train` and `_add` treat the existence of their output file as
        "stage complete". Writing directly to `path` means a crash mid-write
        would leave a truncated file that the next run silently accepts.
        Writing to a temporary sibling file and renaming it into place
        (os.replace) means `path` only ever appears fully written.
        '''
        tmp_path = f"{path}.tmp"
        faiss.write_index(index, tmp_path)
        os.replace(tmp_path, path)

    def _train(self, input_data_loader):
        '''Train index (rank 0's method).

        Args:
            input_data_loader: zero-argument callable whose return value is
                passed directly to `index.train` as the training data.

        The trained (still empty) index is written to
        `self.get_empty_index_path()`. Does nothing if that file exists.
        '''

        assert torch.distributed.get_rank() == 0

        # Set num threads (torch.distributed reset it to 1).
        faiss.omp_set_num_threads(self.num_omp_threads)

        empty_index_path = self.get_empty_index_path()

        # Index already exists? -> return.
        if os.path.isfile(empty_index_path):
            return

        args = get_retro_args()

        # Load data.
        inp = input_data_loader()

        # Init index.
        index = faiss.index_factory(args.retro_index_nfeats,
                                    args.retro_index_str)

        # Replace the IVF clustering index with a flat L2 index replicated
        # across all GPUs, so the clustering step of `index.train` runs on
        # GPU. The index itself stays on the CPU.
        index_ivf = faiss.extract_index_ivf(index)
        index_ivf.clustering_index = \
            faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))

        # c_verbose is defined on the base class; presumably it toggles
        # Faiss verbose/progress output for each sub-index -- confirm.
        self.c_verbose(index, True)
        self.c_verbose(index_ivf, True)
        self.c_verbose(index_ivf.quantizer, True)
        self.c_verbose(index_ivf.clustering_index, True)

        # Train index.
        index.train(inp)

        # Save index.
        self._write_index(index, empty_index_path)

    def train(self, input_data_loader):
        '''Train index.

        Rank 0 performs the training; all ranks then synchronize on a
        barrier so no rank proceeds before the empty index is on disk.
        '''

        # Single process only.
        if torch.distributed.get_rank() == 0:
            self._train(input_data_loader)

        torch.distributed.barrier()

    def _add(self, text_dataset):
        '''Add to index (rank 0's method).

        Embeds `text_dataset` block by block with a BertEmbedder, adds the
        embeddings to the trained index read from
        `self.get_empty_index_path()`, and writes the result to
        `self.get_added_index_path()`. Does nothing if that file exists.
        '''

        assert torch.distributed.get_rank() == 0

        # Set num threads (torch.distributed reset it to 1).
        faiss.omp_set_num_threads(self.num_omp_threads)

        # Empty/added index paths.
        empty_index_path = self.get_empty_index_path()
        added_index_path = self.get_added_index_path()

        # Skip adding, if index exists. This is checked before building the
        # embedder or computing block ranges, since constructing a
        # BertEmbedder is presumably expensive (model load) and would be
        # wasted work on a resumed run.
        if os.path.isfile(added_index_path):
            return

        args = get_retro_args()

        dataset_sample_ranges = num_samples_to_block_ranges(len(text_dataset))

        # Bert embedder.
        embedder = BertEmbedder(args.retro_bert_batch_size,
                                args.retro_bert_max_chunk_length,
                                args.bert_embedder_type)

        # Read trained index.
        index = faiss.read_index(empty_index_path)

        # Iterate data blocks & add.
        for sample_range in tqdm(dataset_sample_ranges, "faiss_base.add"):

            # Embed text.
            embeds = self.embed_text_dataset_block(
                embedder, text_dataset, sample_range)

            # Add to index.
            index.add(embeds)

        # Write index.
        self._write_index(index, added_index_path)

    def add(self, text_dataset):
        '''Add to index.

        Rank 0 performs the add; all ranks then synchronize on a barrier.

        Returns:
            The path of the added (populated) index file, on every rank.
        '''

        # Single process only.
        if torch.distributed.get_rank() == 0:
            self._add(text_dataset)

        # Wait for rank 0.
        torch.distributed.barrier()

        # Get output index path, for return.
        return self.get_added_index_path()