factory.py 591 Bytes
Newer Older
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

from .indexes import FaissBaseIndex, FaissParallelAddIndex


class IndexFactory:
    '''Factory that maps an index-type name to a retro index class/instance.

    Index type generally read from argument '--retro-index-type'.
    Recognized types: "faiss-base" (FaissBaseIndex) and "faiss-par-add"
    (FaissParallelAddIndex).
    '''

    @classmethod
    def get_index_class(cls, index_type):
        '''Return the index class registered under `index_type`.

        Args:
            index_type: index type name, "faiss-base" or "faiss-par-add".

        Returns:
            The (uninstantiated) index class.

        Raises:
            KeyError: if `index_type` is not a recognized index type.
        '''
        # The mapping is built at call time, so the index classes are only
        # looked up when an index is actually requested.
        index_classes = {
            "faiss-base" : FaissBaseIndex,
            "faiss-par-add" : FaissParallelAddIndex,
        }
        try:
            return index_classes[index_type]
        except KeyError:
            # Keep KeyError for backward compatibility, but report the valid
            # choices instead of only echoing the bad key.
            raise KeyError(
                f"unknown index type {index_type!r}; "
                f"expected one of {sorted(index_classes)}"
            ) from None

    @classmethod
    def get_index(cls, index_type):
        '''Instantiate and return an index of type `index_type`.

        The index class is constructed with no arguments.

        Args:
            index_type: index type name, "faiss-base" or "faiss-par-add".

        Raises:
            KeyError: if `index_type` is not a recognized index type.
        '''
        return cls.get_index_class(index_type)()