index.py 1.56 KB
Newer Older
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

import abc
import numpy as np
import os
import torch

from tools.retro.external_libs import faiss

from .utils import get_index_dir


class Index(abc.ABC):

    '''Abstract base class for indexes.

    *Note*: Only Faiss-based subclasses exist at present. In the future this
    class will be extended with other index types that offer different
    performance/accuracy trade-offs.

    Subclasses must override:
    - train() : Train the index on the sampled training chunks.
    - add()   : Add all training chunks to the index.

    The helpers below locate and load the "empty" and "added" Faiss index
    files, both stored in the directory returned by get_index_dir().
    '''

    @classmethod
    def c_verbose(cls, index, v):
        '''Turn verbose output on or off for a Faiss index object.

        Args:
            index: Faiss index whose "verbose" parameter is set.
            v: New value. Must be a bool. This is checked with `assert`, so
               the check is skipped when Python runs with -O.
        '''
        assert isinstance(v, bool)
        param_space = faiss.ParameterSpace()
        param_space.set_index_parameter(index, "verbose", v)

    @staticmethod
    def _index_file_path(file_name):
        '''Return `file_name` joined onto the directory from get_index_dir().'''
        return os.path.join(get_index_dir(), file_name)

    def get_empty_index_path(self):
        '''Return the path of the "empty.faissindex" file in the index dir.'''
        # NOTE(review): "empty" presumably means trained but with no vectors
        # added yet; confirm against the subclass that writes this file.
        return self._index_file_path("empty.faissindex")

    def get_empty_index(self):
        '''Read the "empty" index from disk with faiss.read_index().'''
        path = self.get_empty_index_path()
        return faiss.read_index(path)

    def get_added_index_path(self):
        '''Return the path of the "added.faissindex" file in the index dir.'''
        return self._index_file_path("added.faissindex")

    def get_added_index(self):
        '''Read the "added" index from disk with faiss.read_index().'''
        path = self.get_added_index_path()
        return faiss.read_index(path)

    @abc.abstractmethod
    def train(self, *args):
        '''Train the index on the sampled training chunks (subclass hook).'''

    @abc.abstractmethod
    def add(self, *args):
        '''Add all training chunks to the index (subclass hook).'''

    def embed_text_dataset_block(self, embedder, text_dataset, _range):
        '''Embed the samples of `text_dataset` selected by `range(*_range)`.

        Args:
            embedder: Object providing an `embed_text_dataset(dataset)` method.
            text_dataset: Indexable dataset to take samples from.
            _range: Arguments unpacked into `range()`, e.g. (start, stop).

        Returns:
            Whatever `embedder.embed_text_dataset` returns for the subset.
        '''
        indices = range(*_range)
        block = torch.utils.data.Subset(text_dataset, indices)
        return embedder.embed_text_dataset(block)