Commit 395d2ce6 authored by huchen

Init Faiss for ROCm

parent 5ded39f5
#!/bin/sh
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
set -e
wget -O - https://github.com/Kitware/CMake/releases/download/v3.17.1/cmake-3.17.1-Linux-x86_64.tar.gz | tar xzf -
cp -R cmake-3.17.1-Linux-x86_64/* $PREFIX
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
{% set version = environ.get('GIT_DESCRIBE_TAG').lstrip('v') %}
{% set suffix = "_nightly" if environ.get('PACKAGE_TYPE') == 'nightly' else "" %}
{% set number = GIT_DESCRIBE_NUMBER %}
package:
name: faiss-pkg
version: {{ version }}
build:
number: {{ number }}
about:
home: https://github.com/facebookresearch/faiss
license: MIT
license_family: MIT
license_file: LICENSE
summary: A library for efficient similarity search and clustering of dense vectors.
source:
git_url: ../../
outputs:
- name: libfaiss
script: build-lib.sh
build:
string: "h{{ PKG_HASH }}_{{ number }}_cuda{{ cudatoolkit }}{{ suffix }}"
run_exports:
- {{ pin_compatible('libfaiss', exact=True) }}
script_env:
- CUDA_ARCHS
requirements:
build:
- {{ compiler('cxx') }}
- llvm-openmp # [osx]
- cmake >=3.18
- make # [not win]
host:
- mkl =2018
- cudatoolkit {{ cudatoolkit }}
run:
- mkl >=2018 # [not win]
- mkl >=2018,<2021 # [win]
- {{ pin_compatible('cudatoolkit', max_pin='x.x') }}
test:
commands:
- test -f $PREFIX/lib/libfaiss.so # [linux]
- test -f $PREFIX/lib/libfaiss.dylib # [osx]
- conda inspect linkages -p $PREFIX $PKG_NAME # [not win]
- conda inspect objects -p $PREFIX $PKG_NAME # [osx]
- name: faiss-gpu
script: build-pkg.sh
build:
string: "py{{ PY_VER }}_h{{ PKG_HASH }}_{{ number }}_cuda{{ cudatoolkit }}{{ suffix }}"
requirements:
build:
- {{ compiler('cxx') }}
- swig
- cmake >=3.17
- make # [not win]
host:
- python {{ python }}
- numpy =1.11
- {{ pin_subpackage('libfaiss', exact=True) }}
run:
- python {{ python }}
- numpy >=1.11,<2
- {{ pin_subpackage('libfaiss', exact=True) }}
test:
requires:
- numpy
- scipy
- pytorch
commands:
- python -m unittest discover tests/
- cp tests/common_faiss_tests.py faiss/gpu/test
- python -m unittest discover faiss/gpu/test/
- sh test_cpu_dispatch.sh # [linux]
files:
- test_cpu_dispatch.sh # [linux]
source_files:
- tests/
- faiss/gpu/test/
#!/bin/sh
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
set -e
FAISS_DISABLE_CPU_FEATURES=AVX2 LD_DEBUG=libs python -c "import faiss" 2>&1 | grep libfaiss.so
LD_DEBUG=libs python -c "import faiss" 2>&1 | grep libfaiss_avx2.so
:: Copyright (c) Facebook, Inc. and its affiliates.
::
:: This source code is licensed under the MIT license found in the
:: LICENSE file in the root directory of this source tree.
:: Build the libfaiss library.
cmake -B _build ^
-T v141 ^
-A x64 ^
-G "Visual Studio 16 2019" ^
-DBUILD_SHARED_LIBS=ON ^
-DBUILD_TESTING=OFF ^
-DFAISS_ENABLE_GPU=OFF ^
-DFAISS_ENABLE_PYTHON=OFF ^
-DBLA_VENDOR=Intel10_64_dyn ^
.
if %errorlevel% neq 0 exit /b %errorlevel%
cmake --build _build --config Release -j %CPU_COUNT%
if %errorlevel% neq 0 exit /b %errorlevel%
cmake --install _build --config Release --prefix %PREFIX%
if %errorlevel% neq 0 exit /b %errorlevel%
#!/bin/sh
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
set -e
# Build libfaiss.so/libfaiss_avx2.so.
cmake -B _build \
-DBUILD_SHARED_LIBS=ON \
-DBUILD_TESTING=OFF \
-DFAISS_OPT_LEVEL=avx2 \
-DFAISS_ENABLE_GPU=OFF \
-DFAISS_ENABLE_PYTHON=OFF \
-DBLA_VENDOR=Intel10_64lp \
-DCMAKE_INSTALL_LIBDIR=lib \
-DCMAKE_BUILD_TYPE=Release .
make -C _build -j $CPU_COUNT faiss faiss_avx2
cmake --install _build --prefix $PREFIX
cmake --install _build --prefix _libfaiss_stage/
:: Copyright (c) Facebook, Inc. and its affiliates.
::
:: This source code is licensed under the MIT license found in the
:: LICENSE file in the root directory of this source tree.
:: Build vanilla version (no avx).
cmake -B _build_python_%PY_VER% ^
-T v141 ^
-A x64 ^
-G "Visual Studio 16 2019" ^
-DFAISS_ENABLE_GPU=OFF ^
-DPython_EXECUTABLE=%PYTHON% ^
faiss/python
if %errorlevel% neq 0 exit /b %errorlevel%
cmake --build _build_python_%PY_VER% --config Release -j %CPU_COUNT%
if %errorlevel% neq 0 exit /b %errorlevel%
:: Build actual python module.
cd _build_python_%PY_VER%/
%PYTHON% setup.py install --single-version-externally-managed --record=record.txt --prefix=%PREFIX%
if %errorlevel% neq 0 exit /b %errorlevel%
#!/bin/sh
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
set -e
# Build swigfaiss.so/swigfaiss_avx2.so.
cmake -B _build_python_${PY_VER} \
-Dfaiss_ROOT=_libfaiss_stage/ \
-DFAISS_OPT_LEVEL=avx2 \
-DFAISS_ENABLE_GPU=OFF \
-DCMAKE_BUILD_TYPE=Release \
-DPython_EXECUTABLE=$PYTHON \
faiss/python
make -C _build_python_${PY_VER} -j $CPU_COUNT swigfaiss swigfaiss_avx2
# Build actual python module.
cd _build_python_${PY_VER}/
$PYTHON setup.py install --single-version-externally-managed --record=record.txt --prefix=$PREFIX
#!/bin/sh
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
set -e
wget -O - https://github.com/Kitware/CMake/releases/download/v3.17.1/cmake-3.17.1-Linux-x86_64.tar.gz | tar xzf -
cp -R cmake-3.17.1-Linux-x86_64/* $PREFIX
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
{% set version = environ.get('GIT_DESCRIBE_TAG').lstrip('v') %}
{% set suffix = "_nightly" if environ.get('PACKAGE_TYPE') == 'nightly' else "" %}
{% set number = GIT_DESCRIBE_NUMBER %}
package:
name: faiss-pkg
version: {{ version }}
build:
number: {{ number }}
about:
home: https://github.com/facebookresearch/faiss
license: MIT
license_family: MIT
license_file: LICENSE
summary: A library for efficient similarity search and clustering of dense vectors.
source:
git_url: ../../
outputs:
- name: libfaiss
script: build-lib.sh # [not win]
script: build-lib.bat # [win]
build:
string: "h{{ PKG_HASH }}_{{ number }}_cpu{{ suffix }}"
run_exports:
- {{ pin_compatible('libfaiss', exact=True) }}
requirements:
build:
- {{ compiler('cxx') }}
- llvm-openmp # [osx]
- cmake >=3.17
- make # [not win]
host:
- mkl =2018
run:
- mkl >=2018 # [not win]
- mkl >=2018,<2021 # [win]
test:
commands:
- test -f $PREFIX/lib/libfaiss$SHLIB_EXT # [not win]
- test -f $PREFIX/lib/libfaiss_avx2$SHLIB_EXT # [not win]
- conda inspect linkages -p $PREFIX $PKG_NAME # [not win]
- conda inspect objects -p $PREFIX $PKG_NAME # [osx]
- name: faiss-cpu
script: build-pkg.sh # [not win]
script: build-pkg.bat # [win]
build:
string: "py{{ PY_VER }}_h{{ PKG_HASH }}_{{ number }}_cpu{{ suffix }}"
requirements:
build:
- {{ compiler('cxx') }}
- swig
- cmake >=3.17
- make # [not win]
host:
- python {{ python }}
- numpy =1.11
- {{ pin_subpackage('libfaiss', exact=True) }}
run:
- python {{ python }}
- numpy >=1.11,<2
- {{ pin_subpackage('libfaiss', exact=True) }}
test:
requires:
- numpy
- scipy
- pytorch
commands:
- python -X faulthandler -m unittest discover -v -s tests -p "test_*"
- python -X faulthandler -m unittest discover -v -s tests -p "torch_*"
- sh test_cpu_dispatch.sh # [linux]
files:
- test_cpu_dispatch.sh # [linux]
source_files:
- tests/
#!/bin/sh
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
set -e
FAISS_DISABLE_CPU_FEATURES=AVX2 LD_DEBUG=libs python -c "import faiss" 2>&1 | grep libfaiss.so
LD_DEBUG=libs python -c "import faiss" 2>&1 | grep libfaiss_avx2.so
# The contrib modules
The contrib directory contains helper modules for Faiss for various tasks.
## Code structure
The contrib directory is compiled into the module faiss.contrib.
Note that although some of the modules may depend on additional packages (e.g. GPU Faiss, pytorch, hdf5), they are not necessarily compiled in, to avoid adding dependencies. It is the user's responsibility to provide them.
In contrib, we are progressively dropping python2 support.
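As a quick orientation, here is a minimal sketch of how the modules below are typically imported (availability of a given module depends on its optional dependencies being installed):

```python
import faiss

# plain-numpy helpers, no extra dependencies
from faiss.contrib import evaluation, inspect_tools

# modules with optional dependencies import fine as long as the optional
# pieces are not used, e.g. some datasets in datasets.py need h5py,
# torch_utils needs pytorch
from faiss.contrib import datasets
```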
## List of contrib modules
### rpc.py
A very simple Remote Procedure Call library, where function parameters and results are pickled, for use with client_server.py
### client_server.py
The server handles requests to a Faiss index. The client calls the remote index.
This is mainly used to shard datasets over several machines, see [Distributed index](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#distributed-index).
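A minimal sketch of the intended use, with placeholder host names, port and index files; the server call blocks and would normally run on each remote machine:

```python
import numpy as np
import faiss
from faiss.contrib.client_server import run_index_server, ClientIndex

# on each server machine (blocking call), e.g.:
#     index = faiss.read_index("shard0.index")
#     run_index_server(index, 12010)

# on the client machine, assuming the servers above are up:
client = ClientIndex([("server0", 12010), ("server1", 12010)])
client.set_nprobe(16)
xq = np.random.rand(5, 64).astype('float32')   # must match the index dimension
D, I = client.search(xq, 10)
```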
### ondisk.py
Encloses the main logic to merge indexes into an on-disk index.
See [On-disk storage](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#on-disk-storage)
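A minimal sketch of merging pre-built shards, with placeholder file names; `trained.index` is assumed to be an empty, trained IVF index and the shards to have been populated from copies of it:

```python
import faiss
from faiss.contrib.ondisk import merge_ondisk

index = faiss.read_index("trained.index")   # empty, trained IVF index
merge_ondisk(index, ["shard0.index", "shard1.index"], "merged_index.ivfdata")
faiss.write_index(index, "populated.index") # small file; codes stay in the .ivfdata
```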
### exhaustive_search.py
Computes the ground-truth search results for a dataset that possibly does not fit in RAM. Uses GPU if available.
Tested in `tests/test_contrib.TestComputeGT`
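A small self-contained sketch that feeds the block iterator with the synthetic dataset from datasets.py:

```python
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.exhaustive_search import knn_ground_truth

ds = SyntheticDataset(32, 1000, 2000, 10)      # d, nt, nb, nq
D, I = knn_ground_truth(ds.get_queries(), ds.database_iterator(bs=500), k=10)
```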
### torch_utils.py
Interoperability functions for pytorch and Faiss: Importing this will allow pytorch Tensors (CPU or GPU) to be used as arguments to Faiss indexes and other functions. Torch GPU tensors can only be used with Faiss GPU indexes. If this is imported with a package that supports Faiss GPU, the necessary stream synchronization with the current pytorch stream will be automatically performed.
Numpy ndarrays can continue to be used in the Faiss python interface after importing this file. All arguments must be uniformly either numpy ndarrays or Torch tensors; no mixing is allowed.
Tested in `tests/test_contrib_torch.py` (CPU) and `gpu/test/test_contrib_torch_gpu.py` (GPU).
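A minimal sketch of the CPU case (index dimension and tensor sizes are arbitrary):

```python
import torch
import faiss
import faiss.contrib.torch_utils   # makes Faiss index methods accept torch tensors

index = faiss.IndexFlatL2(64)
index.add(torch.rand(1000, 64))            # float32 CPU tensor
D, I = index.search(torch.rand(5, 64), 10) # results come back as torch tensors
```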
### inspect_tools.py
Functions to inspect C++ objects wrapped by SWIG. Most often this just means reading
fields and converting them to the proper python array.
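For example, a sketch that inspects the inverted lists of a small IVFPQ index built on random data:

```python
import numpy as np
import faiss
from faiss.contrib import inspect_tools

xb = np.random.rand(10000, 32).astype('float32')
index = faiss.index_factory(32, "IVF64,PQ8")
index.train(xb)
index.add(xb)

sizes = inspect_tools.get_invlist_sizes(index.invlists)       # (64,) int64 array
ids0, codes0 = inspect_tools.get_invlist(index.invlists, 0)   # content of list 0
```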
### ivf_tools.py
A few functions to override the coarse quantizer in IVF, providing additional flexibility for assignment.
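A sketch of search_preassigned on random data, where the coarse assignment is simply recomputed with the quantizer (any externally computed assignment of the same shape would do):

```python
import numpy as np
import faiss
from faiss.contrib import ivf_tools

d = 32
xb = np.random.rand(10000, d).astype('float32')
xq = np.random.rand(10, d).astype('float32')

index = faiss.index_factory(d, "IVF64,Flat")
index.train(xb)
index.add(xb)
index.nprobe = 4

# coarse assignment, shape (nq, nprobe); could come from anywhere
_, list_nos = index.quantizer.search(xq, index.nprobe)
D, I = ivf_tools.search_preassigned(index, xq, 5, list_nos)
```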
### datasets.py
(may require h5py)
Definition of how to access data for some standard datasets.
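The SyntheticDataset class needs no data files and is convenient for quick experiments; a sketch of the common accessors:

```python
from faiss.contrib.datasets import SyntheticDataset

ds = SyntheticDataset(64, 4000, 10000, 100)   # d, nt, nb, nq
xt = ds.get_train()               # (4000, 64) training vectors
xb = ds.get_database()            # (10000, 64) database vectors
xq = ds.get_queries()             # (100, 64) query vectors
gt = ds.get_groundtruth(k=10)     # (100, 10) exact nearest neighbors
```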
### factory_tools.py
Functions related to factory strings.
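For example (a sketch; the factory strings are arbitrary):

```python
import faiss
from faiss.contrib.factory_tools import get_code_size, reverse_index_factory

print(get_code_size(128, "OPQ16_64,IVF4096,PQ16"))  # bytes per stored vector -> 16

index = faiss.index_factory(64, "IVF100,Flat")
print(reverse_index_factory(index))                 # "IVF100,Flat"
```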
### evaluation.py
A few non-trivial evaluation functions for search results.
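For instance, knn_intersection_measure gives the fraction of common results between two knn result tables (a sketch with hand-made id arrays):

```python
import numpy as np
from faiss.contrib.evaluation import knn_intersection_measure

I1 = np.array([[1, 2, 3], [4, 5, 6]])
I2 = np.array([[1, 2, 9], [6, 5, 4]])
print(knn_intersection_measure(I1, I2))   # 5 matches out of 6 -> 0.833...
```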
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from multiprocessing.dummy import Pool as ThreadPool
import faiss
from typing import List, Tuple
from . import rpc
############################################################
# Server implementation
############################################################
class SearchServer(rpc.Server):
""" Assign version that can be exposed via RPC """
def __init__(self, s: int, index: faiss.Index):
rpc.Server.__init__(self, s)
self.index = index
self.index_ivf = faiss.extract_index_ivf(index)
def set_nprobe(self, nprobe: int) -> None:
""" set nprobe field """
self.index_ivf.nprobe = nprobe
def get_ntotal(self) -> int:
return self.index.ntotal
def __getattr__(self, f):
# all other functions get forwarded to the index
return getattr(self.index, f)
def run_index_server(index: faiss.Index, port: int, v6: bool = False):
""" serve requests for that index forerver """
rpc.run_server(
lambda s: SearchServer(s, index),
port, v6=v6)
############################################################
# Client implementation
############################################################
class ClientIndex:
"""manages a set of distance sub-indexes. The sub_indexes search a
subset of the inverted lists. Searches are merged afterwards
"""
def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
""" connect to a series of (host, port) pairs """
self.sub_indexes = []
for machine, port in machine_ports:
self.sub_indexes.append(rpc.Client(machine, port, v6))
self.ni = len(self.sub_indexes)
# pool of threads. Each thread manages one sub-index.
self.pool = ThreadPool(self.ni)
# test connection...
self.ntotal = self.get_ntotal()
self.verbose = False
def set_nprobe(self, nprobe: int) -> None:
self.pool.map(
lambda idx: idx.set_nprobe(nprobe),
self.sub_indexes
)
def set_omp_num_threads(self, nt: int) -> None:
self.pool.map(
lambda idx: idx.set_omp_num_threads(nt),
self.sub_indexes
)
def get_ntotal(self) -> int:
return sum(self.pool.map(
lambda idx: idx.get_ntotal(),
self.sub_indexes
))
def search(self, x, k: int):
rh = faiss.ResultHeap(x.shape[0], k)
for Di, Ii in self.pool.imap(lambda idx: idx.search(x, k), self.sub_indexes):
rh.add_result(Di, Ii)
rh.finalize()
return rh.D, rh.I
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import faiss
from .vecs_io import fvecs_read, ivecs_read, bvecs_mmap, fvecs_mmap
from .exhaustive_search import knn
class Dataset:
""" Generic abstract class for a test dataset """
def __init__(self):
""" the constructor should set the following fields: """
self.d = -1
self.metric = 'L2' # or IP
self.nq = -1
self.nb = -1
self.nt = -1
def get_queries(self):
""" return the queries as a (nq, d) array """
raise NotImplementedError()
def get_train(self, maxtrain=None):
""" return the queries as a (nt, d) array """
raise NotImplementedError()
def get_database(self):
""" return the queries as a (nb, d) array """
raise NotImplementedError()
def database_iterator(self, bs=128, split=(1, 0)):
"""returns an iterator on database vectors.
bs is the number of vectors per batch
split = (nsplit, rank) means the dataset is split in nsplit
shards and we want shard number rank
The default implementation just iterates over the full matrix
returned by get_database.
"""
xb = self.get_database()
nsplit, rank = split
i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
for j0 in range(i0, i1, bs):
yield xb[j0: min(j0 + bs, i1)]
def get_groundtruth(self, k=None):
""" return the ground truth for k-nearest neighbor search """
raise NotImplementedError()
def get_groundtruth_range(self, thresh=None):
""" return the ground truth for range search """
raise NotImplementedError()
def __str__(self):
return (f"dataset in dimension {self.d}, with metric {self.metric}, "
f"size: Q {self.nq} B {self.nb} T {self.nt}")
def check_sizes(self):
""" runs the previous and checks the sizes of the matrices """
assert self.get_queries().shape == (self.nq, self.d)
if self.nt > 0:
xt = self.get_train(maxtrain=123)
assert xt.shape == (123, self.d), "shape=%s" % (xt.shape, )
assert self.get_database().shape == (self.nb, self.d)
assert self.get_groundtruth(k=13).shape == (self.nq, 13)
class SyntheticDataset(Dataset):
"""A dataset that is not completely random but still challenging to
index
"""
def __init__(self, d, nt, nb, nq, metric='L2', seed=1338):
Dataset.__init__(self)
self.d, self.nt, self.nb, self.nq = d, nt, nb, nq
d1 = 10 # intrinsic dimension (more or less)
n = nb + nt + nq
rs = np.random.RandomState(seed)
x = rs.normal(size=(n, d1))
x = np.dot(x, rs.rand(d1, d))
# now we have a d1-dim ellipsoid in d-dimensional space
# higher factor (>4) -> higher frequency -> less linear
x = x * (rs.rand(d) * 4 + 0.1)
x = np.sin(x)
x = x.astype('float32')
self.metric = metric
self.xt = x[:nt]
self.xb = x[nt:nt + nb]
self.xq = x[nt + nb:]
def get_queries(self):
return self.xq
def get_train(self, maxtrain=None):
maxtrain = maxtrain if maxtrain is not None else self.nt
return self.xt[:maxtrain]
def get_database(self):
return self.xb
def get_groundtruth(self, k=100):
return knn(
self.xq, self.xb, k,
faiss.METRIC_L2 if self.metric == 'L2' else faiss.METRIC_INNER_PRODUCT
)[1]
############################################################################
# The following datasets are a few standard open-source datasets
# they should be stored in a directory, and we start by guessing where
# that directory is
############################################################################
for dataset_basedir in (
'/datasets01/simsearch/041218/',
'/mnt/vol/gfsai-flash3-east/ai-group/datasets/simsearch/'):
if os.path.exists(dataset_basedir):
break
else:
# users can link their data directory to `./data`
dataset_basedir = 'data/'
class DatasetSIFT1M(Dataset):
"""
The original dataset is available at: http://corpus-texmex.irisa.fr/
(ANN_SIFT1M)
"""
def __init__(self):
Dataset.__init__(self)
self.d, self.nt, self.nb, self.nq = 128, 100000, 1000000, 10000
self.basedir = dataset_basedir + 'sift1M/'
def get_queries(self):
return fvecs_read(self.basedir + "sift_query.fvecs")
def get_train(self, maxtrain=None):
maxtrain = maxtrain if maxtrain is not None else self.nt
return fvecs_read(self.basedir + "sift_learn.fvecs")[:maxtrain]
def get_database(self):
return fvecs_read(self.basedir + "sift_base.fvecs")
def get_groundtruth(self, k=None):
gt = ivecs_read(self.basedir + "sift_groundtruth.ivecs")
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
def sanitize(x):
return np.ascontiguousarray(x, dtype='float32')
class DatasetBigANN(Dataset):
"""
The original dataset is available at: http://corpus-texmex.irisa.fr/
(ANN_SIFT1B)
"""
def __init__(self, nb_M=1000):
Dataset.__init__(self)
assert nb_M in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000)
self.nb_M = nb_M
nb = nb_M * 10**6
self.d, self.nt, self.nb, self.nq = 128, 10**8, nb, 10000
self.basedir = dataset_basedir + 'bigann/'
def get_queries(self):
return sanitize(bvecs_mmap(self.basedir + 'bigann_query.bvecs')[:])
def get_train(self, maxtrain=None):
maxtrain = maxtrain if maxtrain is not None else self.nt
return sanitize(bvecs_mmap(self.basedir + 'bigann_learn.bvecs')[:maxtrain])
def get_groundtruth(self, k=None):
gt = ivecs_read(self.basedir + 'gnd/idx_%dM.ivecs' % self.nb_M)
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
def get_database(self):
assert self.nb_M < 100, "dataset too large, use iterator"
return sanitize(bvecs_mmap(self.basedir + 'bigann_base.bvecs')[:self.nb])
def database_iterator(self, bs=128, split=(1, 0)):
xb = bvecs_mmap(self.basedir + 'bigann_base.bvecs')
nsplit, rank = split
i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
for j0 in range(i0, i1, bs):
yield sanitize(xb[j0: min(j0 + bs, i1)])
class DatasetDeep1B(Dataset):
"""
See
https://github.com/facebookresearch/faiss/tree/main/benchs#getting-deep1b
on how to get the data
"""
def __init__(self, nb=10**9):
Dataset.__init__(self)
nb_to_name = {
10**5: '100k',
10**6: '1M',
10**7: '10M',
10**8: '100M',
10**9: '1B'
}
assert nb in nb_to_name
self.d, self.nt, self.nb, self.nq = 96, 358480000, nb, 10000
self.basedir = dataset_basedir + 'deep1b/'
self.gt_fname = "%sdeep%s_groundtruth.ivecs" % (
self.basedir, nb_to_name[self.nb])
def get_queries(self):
return sanitize(fvecs_read(self.basedir + "deep1B_queries.fvecs"))
def get_train(self, maxtrain=None):
maxtrain = maxtrain if maxtrain is not None else self.nt
return sanitize(fvecs_mmap(self.basedir + "learn.fvecs")[:maxtrain])
def get_groundtruth(self, k=None):
gt = ivecs_read(self.gt_fname)
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
def get_database(self):
assert self.nb <= 10**8, "dataset too large, use iterator"
return sanitize(fvecs_mmap(self.basedir + "base.fvecs")[:self.nb])
def database_iterator(self, bs=128, split=(1, 0)):
xb = fvecs_mmap(self.basedir + "base.fvecs")
nsplit, rank = split
i0, i1 = self.nb * rank // nsplit, self.nb * (rank + 1) // nsplit
for j0 in range(i0, i1, bs):
yield sanitize(xb[j0: min(j0 + bs, i1)])
class DatasetGlove(Dataset):
"""
Data from http://ann-benchmarks.com/glove-100-angular.hdf5
"""
def __init__(self, loc=None, download=False):
import h5py
assert not download, "not implemented"
if not loc:
loc = dataset_basedir + 'glove/glove-100-angular.hdf5'
self.glove_h5py = h5py.File(loc, 'r')
# IP and L2 are equivalent in this case, but it is traditionally seen as an IP dataset
self.metric = 'IP'
self.d, self.nt = 100, 0
self.nb = self.glove_h5py['train'].shape[0]
self.nq = self.glove_h5py['test'].shape[0]
def get_queries(self):
xq = np.array(self.glove_h5py['test'])
faiss.normalize_L2(xq)
return xq
def get_database(self):
xb = np.array(self.glove_h5py['train'])
faiss.normalize_L2(xb)
return xb
def get_groundtruth(self, k=None):
gt = self.glove_h5py['neighbors']
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
class DatasetMusic100(Dataset):
"""
get dataset from
https://github.com/stanis-morozov/ip-nsw#dataset
"""
def __init__(self):
Dataset.__init__(self)
self.d, self.nt, self.nb, self.nq = 100, 0, 10**6, 10000
self.metric = 'IP'
self.basedir = dataset_basedir + 'music-100/'
def get_queries(self):
xq = np.fromfile(self.basedir + 'query_music100.bin', dtype='float32')
xq = xq.reshape(-1, 100)
return xq
def get_database(self):
xb = np.fromfile(self.basedir + 'database_music100.bin', dtype='float32')
xb = xb.reshape(-1, 100)
return xb
def get_groundtruth(self, k=None):
gt = np.load(self.basedir + 'gt.npy')
if k is not None:
assert k <= 100
gt = gt[:, :k]
return gt
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import unittest
from multiprocessing.dummy import Pool as ThreadPool
###############################################################
# Simple functions to evaluate knn results
def knn_intersection_measure(I1, I2):
""" computes the intersection measure of two result tables
"""
nq, rank = I1.shape
assert I2.shape == (nq, rank)
ninter = sum(
np.intersect1d(I1[i], I2[i]).size
for i in range(nq)
)
return ninter / I1.size
###############################################################
# Range search results can be compared with Precision-Recall
def filter_range_results(lims, D, I, thresh):
""" select a set of results """
nq = lims.size - 1
mask = D < thresh
new_lims = np.zeros_like(lims)
for i in range(nq):
new_lims[i + 1] = new_lims[i] + mask[lims[i] : lims[i + 1]].sum()
return new_lims, D[mask], I[mask]
def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"):
"""compute the precision and recall of range search results. The
function does not take the distances into account. """
def ref_result_for(i):
return Iref[lims_ref[i]:lims_ref[i + 1]]
def new_result_for(i):
return Inew[lims_new[i]:lims_new[i + 1]]
nq = lims_ref.size - 1
assert lims_new.size - 1 == nq
ninter = np.zeros(nq, dtype="int64")
def compute_PR_for(q):
# ground truth results for this query
gt_ids = ref_result_for(q)
# results for this query
new_ids = new_result_for(q)
# there are no set functions in numpy so let's do this
inter = np.intersect1d(gt_ids, new_ids)
ninter[q] = len(inter)
# run in a thread pool, which helps in spite of the GIL
pool = ThreadPool(20)
pool.map(compute_PR_for, range(nq))
return counts_to_PR(
lims_ref[1:] - lims_ref[:-1],
lims_new[1:] - lims_new[:-1],
ninter,
mode=mode
)
def counts_to_PR(ngt, nres, ninter, mode="overall"):
""" computes a precision-recall for a ser of queries.
ngt = nb of GT results per query
nres = nb of found results per query
ninter = nb of correct results per query (smaller than nres of course)
"""
if mode == "overall":
ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum()
if nres > 0:
precision = ninter / nres
else:
precision = 1.0
if ngt > 0:
recall = ninter / ngt
elif nres == 0:
recall = 1.0
else:
recall = 0.0
return precision, recall
elif mode == "average":
# average precision and recall over queries
mask = ngt == 0
ngt[mask] = 1
recalls = ninter / ngt
recalls[mask] = (nres[mask] == 0).astype(float)
# avoid division by 0
mask = nres == 0
assert np.all(ninter[mask] == 0)
ninter[mask] = 1
nres[mask] = 1
precisions = ninter / nres
return precisions.mean(), recalls.mean()
else:
raise AssertionError()
def sort_range_res_2(lims, D, I):
""" sort 2 arrays using the first as key """
I2 = np.empty_like(I)
D2 = np.empty_like(D)
nq = len(lims) - 1
for i in range(nq):
l0, l1 = lims[i], lims[i + 1]
ii = I[l0:l1]
di = D[l0:l1]
o = di.argsort()
I2[l0:l1] = ii[o]
D2[l0:l1] = di[o]
return I2, D2
def sort_range_res_1(lims, I):
I2 = np.empty_like(I)
nq = len(lims) - 1
for i in range(nq):
l0, l1 = lims[i], lims[i + 1]
I2[l0:l1] = I[l0:l1]
I2[l0:l1].sort()
return I2
def range_PR_multiple_thresholds(
lims_ref, Iref,
lims_new, Dnew, Inew,
thresholds,
mode="overall", do_sort="ref,new"
):
""" compute precision-recall values for range search results
for several thresholds on the "new" results.
This is to plot PR curves
"""
# ref should be sorted by ids
if "ref" in do_sort:
Iref = sort_range_res_1(lims_ref, Iref)
# new should be sorted by distances
if "new" in do_sort:
Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew)
def ref_result_for(i):
return Iref[lims_ref[i]:lims_ref[i + 1]]
def new_result_for(i):
l0, l1 = lims_new[i], lims_new[i + 1]
return Inew[l0:l1], Dnew[l0:l1]
nq = lims_ref.size - 1
assert lims_new.size - 1 == nq
nt = len(thresholds)
counts = np.zeros((nq, nt, 3), dtype="int64")
def compute_PR_for(q):
gt_ids = ref_result_for(q)
res_ids, res_dis = new_result_for(q)
counts[q, :, 0] = len(gt_ids)
if res_dis.size == 0:
# the rest remains at 0
return
# which offsets we are interested in
nres = np.searchsorted(res_dis, thresholds)
counts[q, :, 1] = nres
if gt_ids.size == 0:
return
# find number of TPs at each stage in the result list
ii = np.searchsorted(gt_ids, res_ids)
ii[ii == len(gt_ids)] = -1
n_ok = np.cumsum(gt_ids[ii] == res_ids)
# focus on threshold points
n_ok = np.hstack(([0], n_ok))
counts[q, :, 2] = n_ok[nres]
pool = ThreadPool(20)
pool.map(compute_PR_for, range(nq))
# print(counts.transpose(2, 1, 0))
precisions = np.zeros(nt)
recalls = np.zeros(nt)
for t in range(nt):
p, r = counts_to_PR(
counts[:, t, 0], counts[:, t, 1], counts[:, t, 2],
mode=mode
)
precisions[t] = p
recalls[t] = r
return precisions, recalls
###############################################################
# Functions that compare search results with a reference result.
# They are intended for use in tests
def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
""" test that knn search results are identical, raise if not """
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
# here we have to be careful because of draws
testcase = unittest.TestCase() # because it makes nice error messages
for i in range(len(Iref)):
if np.all(Iref[i] == Inew[i]): # easy case
continue
# we can deduce nothing about the latest line
skip_dis = Dref[i, -1]
for dis in np.unique(Dref):
if dis == skip_dis:
continue
mask = Dref[i, :] == dis
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))
def test_ref_range_results(lims_ref, Dref, Iref,
lims_new, Dnew, Inew):
""" compare range search results wrt. a reference result,
throw if it fails """
np.testing.assert_array_equal(lims_ref, lims_new)
nq = len(lims_ref) - 1
for i in range(nq):
l0, l1 = lims_ref[i], lims_ref[i + 1]
Ii_ref = Iref[l0:l1]
Ii_new = Inew[l0:l1]
Di_ref = Dref[l0:l1]
Di_new = Dnew[l0:l1]
if np.all(Ii_ref == Ii_new): # easy
pass
else:
def sort_by_ids(I, D):
o = I.argsort()
return I[o], D[o]
# sort both
(Ii_ref, Di_ref) = sort_by_ids(Ii_ref, Di_ref)
(Ii_new, Di_new) = sort_by_ids(Ii_new, Di_new)
np.testing.assert_array_equal(Ii_ref, Ii_new)
np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import faiss
import time
import numpy as np
import logging
LOG = logging.getLogger(__name__)
def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
"""Computes the exact KNN search results for a dataset that possibly
does not fit in RAM but for which we have an iterator that
returns it block by block.
"""
LOG.info("knn_ground_truth queries size %s k=%d" % (xq.shape, k))
t0 = time.time()
nq, d = xq.shape
rh = faiss.ResultHeap(nq, k)
index = faiss.IndexFlat(d, metric_type)
if faiss.get_num_gpus():
LOG.info('running on %d GPUs' % faiss.get_num_gpus())
index = faiss.index_cpu_to_all_gpus(index)
# compute ground-truth by blocks, and add to heaps
i0 = 0
for xbi in db_iterator:
ni = xbi.shape[0]
index.add(xbi)
D, I = index.search(xq, k)
I += i0
rh.add_result(D, I)
index.reset()
i0 += ni
LOG.info("%d db elements, %.3f s" % (i0, time.time() - t0))
rh.finalize()
LOG.info("GT time: %.3f s (%d vectors)" % (time.time() - t0, i0))
return rh.D, rh.I
# knn function used to be here
knn = faiss.knn
def range_search_gpu(xq, r2, index_gpu, index_cpu):
"""GPU does not support range search, so we emulate it with
knn search + fallback to CPU index.
The index_cpu can either be a CPU index or a numpy table that will
be used to construct a Flat index if needed.
"""
nq, d = xq.shape
LOG.debug("GPU search %d queries" % nq)
k = min(index_gpu.ntotal, 1024)
D, I = index_gpu.search(xq, k)
if index_gpu.metric_type == faiss.METRIC_L2:
mask = D[:, k - 1] < r2
else:
mask = D[:, k - 1] > r2
if mask.sum() > 0:
LOG.debug("CPU search remain %d" % mask.sum())
if isinstance(index_cpu, np.ndarray):
# then it is in fact an array from which we build a Flat index
xb = index_cpu
index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
index_cpu.add(xb)
lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)
LOG.debug("combine")
D_res, I_res = [], []
nr = 0
for i in range(nq):
if not mask[i]:
if index_gpu.metric_type == faiss.METRIC_L2:
nv = (D[i, :] < r2).sum()
else:
nv = (D[i, :] > r2).sum()
D_res.append(D[i, :nv])
I_res.append(I[i, :nv])
else:
l0, l1 = lim_remain[nr], lim_remain[nr + 1]
D_res.append(D_remain[l0:l1])
I_res.append(I_remain[l0:l1])
nr += 1
lims = np.cumsum([0] + [len(di) for di in D_res])
return lims, np.hstack(D_res), np.hstack(I_res)
def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
shard=False, ngpu=-1):
"""Computes the range-search search results for a dataset that possibly
does not fit in RAM but for which we have an iterator that
returns it block by block.
"""
nq, d = xq.shape
t0 = time.time()
xq = np.ascontiguousarray(xq, dtype='float32')
index = faiss.IndexFlat(d, metric_type)
if ngpu == -1:
ngpu = faiss.get_num_gpus()
if ngpu:
LOG.info('running on %d GPUs' % ngpu)
co = faiss.GpuMultipleClonerOptions()
co.shard = shard
index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)
# compute ground-truth by blocks
i0 = 0
D = [[] for _i in range(nq)]
I = [[] for _i in range(nq)]
for xbi in db_iterator:
ni = xbi.shape[0]
if ngpu > 0:
index_gpu.add(xbi)
lims_i, Di, Ii = range_search_gpu(xq, threshold, index_gpu, xbi)
index_gpu.reset()
else:
index.add(xbi)
lims_i, Di, Ii = index.range_search(xq, threshold)
index.reset()
Ii += i0
for j in range(nq):
l0, l1 = lims_i[j], lims_i[j + 1]
if l1 > l0:
D[j].append(Di[l0:l1])
I[j].append(Ii[l0:l1])
i0 += ni
LOG.info("%d db elements, %.3f s" % (i0, time.time() - t0))
empty_I = np.zeros(0, dtype='int64')
empty_D = np.zeros(0, dtype='float32')
# import pdb; pdb.set_trace()
D = [(np.hstack(i) if i != [] else empty_D) for i in D]
I = [(np.hstack(i) if i != [] else empty_I) for i in I]
sizes = [len(i) for i in I]
assert len(sizes) == nq
lims = np.zeros(nq + 1, dtype="uint64")
lims[1:] = np.cumsum(sizes)
return lims, np.hstack(D), np.hstack(I)
def threshold_radius_nres(nres, dis, ids, thresh, keep_max=False):
""" select a set of results """
if keep_max:
mask = dis > thresh
else:
mask = dis < thresh
new_nres = np.zeros_like(nres)
o = 0
for i, nr in enumerate(nres):
nr = int(nr) # avoid issues with int64 + uint64
new_nres[i] = mask[o:o + nr].sum()
o += nr
return new_nres, dis[mask], ids[mask]
def threshold_radius(lims, dis, ids, thresh, keep_max=False):
""" restrict range-search results to those below a given radius """
if keep_max:
mask = dis > thresh
else:
mask = dis < thresh
new_lims = np.zeros_like(lims)
n = len(lims) - 1
for i in range(n):
l0, l1 = lims[i], lims[i + 1]
new_lims[i + 1] = new_lims[i] + mask[l0:l1].sum()
return new_lims, dis[mask], ids[mask]
def apply_maxres(res_batches, target_nres, keep_max=False):
"""find radius that reduces number of results to target_nres, and
applies it in-place to the result batches used in
range_search_max_results"""
alldis = np.hstack([dis for _, dis, _ in res_batches])
assert len(alldis) > target_nres
if keep_max:
alldis.partition(len(alldis) - target_nres - 1)
radius = alldis[-1 - target_nres]
else:
alldis.partition(target_nres)
radius = alldis[target_nres]
if alldis.dtype == 'float32':
radius = float(radius)
else:
radius = int(radius)
LOG.debug(' setting radius to %s' % radius)
totres = 0
for i, (nres, dis, ids) in enumerate(res_batches):
nres, dis, ids = threshold_radius_nres(
nres, dis, ids, radius, keep_max=keep_max)
totres += len(dis)
res_batches[i] = nres, dis, ids
LOG.debug(' updated previous results, new nb results %d' % totres)
return radius, totres
def range_search_max_results(index, query_iterator, radius,
max_results=None, min_results=None,
shard=False, ngpu=0, clip_to_min=False):
"""Performs a range search with many queries (given by an iterator)
and adjusts the threshold on-the-fly so that the total results
table does not grow larger than max_results.
If ngpu != 0, the function moves the index to this many GPUs to
speed up search.
"""
# TODO: all result manipulations are in python, should move to C++ if perf
# critical
if min_results is None:
assert max_results is not None
min_results = int(0.8 * max_results)
if max_results is None:
assert min_results is not None
max_results = int(min_results * 1.5)
if ngpu == -1:
ngpu = faiss.get_num_gpus()
if ngpu:
LOG.info('running on %d GPUs' % ngpu)
co = faiss.GpuMultipleClonerOptions()
co.shard = shard
index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)
t_start = time.time()
t_search = t_post_process = 0
qtot = totres = raw_totres = 0
res_batches = []
for xqi in query_iterator:
t0 = time.time()
if ngpu > 0:
lims_i, Di, Ii = range_search_gpu(xqi, radius, index_gpu, index)
else:
lims_i, Di, Ii = index.range_search(xqi, radius)
nres_i = lims_i[1:] - lims_i[:-1]
raw_totres += len(Di)
qtot += len(xqi)
t1 = time.time()
if xqi.dtype != np.float32:
# for binary indexes
# weird Faiss quirk that returns floats for Hamming distances
Di = Di.astype('int16')
totres += len(Di)
res_batches.append((nres_i, Di, Ii))
if max_results is not None and totres > max_results:
LOG.info('too many results %d > %d, scaling back radius' %
(totres, max_results))
radius, totres = apply_maxres(
res_batches, min_results,
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
)
t2 = time.time()
t_search += t1 - t0
t_post_process += t2 - t1
LOG.debug(' [%.3f s] %d queries done, %d results' % (
time.time() - t_start, qtot, totres))
LOG.info(
'search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
t_search, t_post_process, totres, radius)
)
if clip_to_min and totres > min_results:
radius, totres = apply_maxres(
res_batches, min_results,
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
)
nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches])
dis = np.hstack([dis_i for nres_i, dis_i, ids_i in res_batches])
ids = np.hstack([ids_i for nres_i, dis_i, ids_i in res_batches])
lims = np.zeros(len(nres) + 1, dtype='uint64')
lims[1:] = np.cumsum(nres)
return radius, lims, dis, ids
def exponential_query_iterator(xq, start_bs=32, max_bs=20000):
""" produces batches of progressively increasing sizes. This is useful to
adjust the search radius progressively without overflowing with
intermediate results """
nq = len(xq)
bs = start_bs
i = 0
while i < nq:
xqi = xq[i:i + bs]
yield xqi
if bs < max_bs:
bs *= 2
i += len(xqi)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import faiss
import re
def get_code_size(d, indexkey):
""" size of one vector in an index in dimension d
constructed with factory string indexkey"""
if indexkey == "Flat":
return d * 4
if indexkey.endswith(",RFlat"):
return d * 4 + get_code_size(d, indexkey[:-len(",RFlat")])
mo = re.match("IVF\\d+(_HNSW32)?,(.*)$", indexkey)
if mo:
return get_code_size(d, mo.group(2))
mo = re.match("IVF\\d+\\(.*\\)?,(.*)$", indexkey)
if mo:
return get_code_size(d, mo.group(1))
mo = re.match("IMI\\d+x2,(.*)$", indexkey)
if mo:
return get_code_size(d, mo.group(1))
mo = re.match("(.*),Refine\\((.*)\\)$", indexkey)
if mo:
return get_code_size(d, mo.group(1)) + get_code_size(d, mo.group(2))
mo = re.match('PQ(\\d+)x(\\d+)(fs|fsr)?$', indexkey)
if mo:
return (int(mo.group(1)) * int(mo.group(2)) + 7) // 8
mo = re.match('PQ(\\d+)\\+(\\d+)$', indexkey)
if mo:
return (int(mo.group(1)) + int(mo.group(2)))
mo = re.match('PQ(\\d+)$', indexkey)
if mo:
return int(mo.group(1))
if indexkey == "HNSW32" or indexkey == "HNSW32,Flat":
return d * 4 + 64 * 4 # roughly
if indexkey == 'SQ8':
return d
elif indexkey == 'SQ4':
return (d + 1) // 2
elif indexkey == 'SQ6':
return (d * 6 + 7) // 8
elif indexkey == 'SQfp16':
return d * 2
mo = re.match('PCAR?(\\d+),(.*)$', indexkey)
if mo:
return get_code_size(int(mo.group(1)), mo.group(2))
mo = re.match('OPQ\\d+_(\\d+),(.*)$', indexkey)
if mo:
return get_code_size(int(mo.group(1)), mo.group(2))
mo = re.match('OPQ\\d+,(.*)$', indexkey)
if mo:
return get_code_size(d, mo.group(1))
mo = re.match('RR(\\d+),(.*)$', indexkey)
if mo:
return get_code_size(int(mo.group(1)), mo.group(2))
raise RuntimeError("cannot parse " + indexkey)
def reverse_index_factory(index):
"""
attempts to get the factory string the index was built with
"""
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexFlat):
return "Flat"
if isinstance(index, faiss.IndexIVF):
quantizer = faiss.downcast_index(index.quantizer)
if isinstance(quantizer, faiss.IndexFlat):
prefix = "IVF%d" % index.nlist
elif isinstance(quantizer, faiss.MultiIndexQuantizer):
prefix = "IMI%dx%d" % (quantizer.pq.M, quantizer.pq.nbit)
elif isinstance(quantizer, faiss.IndexHNSW):
prefix = "IVF%d_HNSW%d" % (index.nlist, quantizer.hnsw.M)
else:
prefix = "IVF%d(%s)" % (index.nlist, reverse_index_factory(quantizer))
if isinstance(index, faiss.IndexIVFFlat):
return prefix + ",Flat"
if isinstance(index, faiss.IndexIVFScalarQuantizer):
return prefix + ",SQ8"
raise NotImplementedError()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import faiss
def get_invlist(invlists, l):
""" returns the inverted lists content as a pair of (list_ids, list_codes).
The codes are reshaped to a proper size
"""
invlists = faiss.downcast_InvertedLists(invlists)
ls = invlists.list_size(l)
list_ids = np.zeros(ls, dtype='int64')
ids = codes = None
try:
ids = invlists.get_ids(l)
if ls > 0:
faiss.memcpy(faiss.swig_ptr(list_ids), ids, list_ids.nbytes)
codes = invlists.get_codes(l)
if invlists.code_size != faiss.InvertedLists.INVALID_CODE_SIZE:
list_codes = np.zeros((ls, invlists.code_size), dtype='uint8')
else:
# it's a BlockInvertedLists
npb = invlists.n_per_block
bs = invlists.block_size
ls_round = (ls + npb - 1) // npb
list_codes = np.zeros((ls_round, bs // npb, npb), dtype='uint8')
if ls > 0:
faiss.memcpy(faiss.swig_ptr(list_codes), codes, list_codes.nbytes)
finally:
if ids is not None:
invlists.release_ids(l, ids)
if codes is not None:
invlists.release_codes(l, codes)
return list_ids, list_codes
def get_invlist_sizes(invlists):
""" return the array of sizes of the inverted lists """
return np.array([
invlists.list_size(i)
for i in range(invlists.nlist)
], dtype='int64')
def print_object_fields(obj):
""" list values all fields of an object known to SWIG """
for name in obj.__class__.__swig_getmethods__:
print(f"{name} = {getattr(obj, name)}")
def get_pq_centroids(pq):
""" return the PQ centroids as an array """
cen = faiss.vector_to_array(pq.centroids)
return cen.reshape(pq.M, pq.ksub, pq.dsub)
def get_LinearTransform_matrix(pca):
""" extract matrix + bias from the PCA object
works for any linear transform (OPQ, random rotation, etc.)
"""
b = faiss.vector_to_array(pca.b)
A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in)
return A, b
def get_additive_quantizer_codebooks(aq):
""" return to codebooks of an additive quantizer """
codebooks = faiss.vector_to_array(aq.codebooks).reshape(-1, aq.d)
co = faiss.vector_to_array(aq.codebook_offsets)
return [
codebooks[co[i]:co[i + 1]]
for i in range(aq.M)
]
def get_flat_data(index):
""" copy and return the data matrix in an IndexFlat """
xb = faiss.vector_to_array(index.codes).view("float32")
return xb.reshape(index.ntotal, index.d)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import faiss
def add_preassigned(index_ivf, x, a, ids=None):
"""
Add elements to an IVF index, where the assignment is already computed
"""
n, d = x.shape
assert a.shape == (n, )
if isinstance(index_ivf, faiss.IndexBinaryIVF):
d *= 8
assert d == index_ivf.d
if ids is not None:
assert ids.shape == (n, )
ids = faiss.swig_ptr(ids)
index_ivf.add_core(
n, faiss.swig_ptr(x), ids, faiss.swig_ptr(a)
)
def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
"""
Perform a search in the IVF index, with predefined lists to search into
"""
n, d = xq.shape
if isinstance(index_ivf, faiss.IndexBinaryIVF):
d *= 8
dis_type = "int32"
else:
dis_type = "float32"
assert d == index_ivf.d
assert list_nos.shape == (n, index_ivf.nprobe)
# the coarse distances are used in IVFPQ with L2 distance and by_residual=True
# otherwise we provide dummy coarse_dis
if coarse_dis is None:
coarse_dis = np.zeros((n, index_ivf.nprobe), dtype=dis_type)
else:
assert coarse_dis.shape == (n, index_ivf.nprobe)
D = np.empty((n, k), dtype=dis_type)
I = np.empty((n, k), dtype='int64')
sp = faiss.swig_ptr
index_ivf.search_preassigned(
n, sp(xq), k,
sp(list_nos), sp(coarse_dis), sp(D), sp(I), False)
return D, I
def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
"""
Perform a range search in the IVF index, with predefined lists to search into
"""
n, d = x.shape
if isinstance(index_ivf, faiss.IndexBinaryIVF):
d *= 8
dis_type = "int32"
else:
dis_type = "float32"
# the coarse distances are used in IVFPQ with L2 distance and by_residual=True
# otherwise we provide dummy coarse_dis
if coarse_dis is None:
coarse_dis = np.empty((n, index_ivf.nprobe), dtype=dis_type)
else:
assert coarse_dis.shape == (n, index_ivf.nprobe)
assert d == index_ivf.d
assert list_nos.shape == (n, index_ivf.nprobe)
res = faiss.RangeSearchResult(n)
sp = faiss.swig_ptr
index_ivf.range_search_preassigned(
n, sp(x), radius,
sp(list_nos), sp(coarse_dis),
res
)
# get pointers and copy them
lims = faiss.rev_swig_ptr(res.lims, n + 1).copy()
num_results = int(lims[-1])
dist = faiss.rev_swig_ptr(res.distances, num_results).copy()
indices = faiss.rev_swig_ptr(res.labels, num_results).copy()
return lims, dist, indices
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import faiss
import logging
LOG = logging.getLogger(__name__)
def merge_ondisk(trained_index: faiss.Index,
shard_fnames: List[str],
ivfdata_fname: str) -> None:
""" Add the contents of the indexes stored in shard_fnames into the index
trained_index. The on-disk data is stored in ivfdata_fname """
# merge the images into an on-disk index
# first load the inverted lists
ivfs = []
for fname in shard_fnames:
# the IO_FLAG_MMAP is to avoid actually loading the data thus
# the total size of the inverted lists can exceed the
# available RAM
LOG.info("read " + fname)
index = faiss.read_index(fname, faiss.IO_FLAG_MMAP)
index_ivf = faiss.extract_index_ivf(index)
ivfs.append(index_ivf.invlists)
# avoid that the invlists get deallocated with the index
index_ivf.own_invlists = False
# construct the output index
index = trained_index
index_ivf = faiss.extract_index_ivf(index)
assert index.ntotal == 0, "works only on empty index"
# prepare the output inverted lists. They will be written
# to merged_index.ivfdata
invlists = faiss.OnDiskInvertedLists(
index_ivf.nlist, index_ivf.code_size,
ivfdata_fname)
# merge all the inverted lists
ivf_vector = faiss.InvertedListsPtrVector()
for ivf in ivfs:
ivf_vector.push_back(ivf)
LOG.info("merge %d inverted lists " % ivf_vector.size())
ntotal = invlists.merge_from(ivf_vector.data(), ivf_vector.size())
# now replace the inverted lists in the output index
index.ntotal = index_ivf.ntotal = ntotal
index_ivf.replace_invlists(invlists, True)
invlists.this.disown()