Commit 395d2ce6 authored by huchen

init the faiss for rocm

parent 5ded39f5
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/IndexLSH.h>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/utils.h>
namespace faiss {
/***************************************************************
* IndexLSH
***************************************************************/
IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
: IndexFlatCodes((nbits + 7) / 8, d),
nbits(nbits),
rotate_data(rotate_data),
train_thresholds(train_thresholds),
rrot(d, nbits) {
is_trained = !train_thresholds;
if (rotate_data) {
rrot.init(5);
} else {
FAISS_THROW_IF_NOT(d >= nbits);
}
}
IndexLSH::IndexLSH() : nbits(0), rotate_data(false), train_thresholds(false) {}
const float* IndexLSH::apply_preprocess(idx_t n, const float* x) const {
float* xt = nullptr;
if (rotate_data) {
// also applies bias if exists
xt = rrot.apply(n, x);
} else if (d != nbits) {
assert(nbits < d);
xt = new float[nbits * n];
float* xp = xt;
for (idx_t i = 0; i < n; i++) {
const float* xl = x + i * d;
for (int j = 0; j < nbits; j++)
*xp++ = xl[j];
}
}
if (train_thresholds) {
if (xt == NULL) {
xt = new float[nbits * n];
memcpy(xt, x, sizeof(*x) * n * nbits);
}
float* xp = xt;
for (idx_t i = 0; i < n; i++)
for (int j = 0; j < nbits; j++)
*xp++ -= thresholds[j];
}
return xt ? xt : x;
}
void IndexLSH::train(idx_t n, const float* x) {
if (train_thresholds) {
thresholds.resize(nbits);
train_thresholds = false;
const float* xt = apply_preprocess(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
train_thresholds = true;
float* transposed_x = new float[n * nbits];
ScopeDeleter<float> del2(transposed_x);
for (idx_t i = 0; i < n; i++)
for (idx_t j = 0; j < nbits; j++)
transposed_x[j * n + i] = xt[i * nbits + j];
for (idx_t i = 0; i < nbits; i++) {
float* xi = transposed_x + i * n;
// std::nth_element
std::sort(xi, xi + n);
if (n % 2 == 1)
thresholds[i] = xi[n / 2];
else
thresholds[i] = (xi[n / 2 - 1] + xi[n / 2]) / 2;
}
}
is_trained = true;
}
void IndexLSH::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_preprocess(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
uint8_t* qcodes = new uint8_t[n * code_size];
ScopeDeleter<uint8_t> del2(qcodes);
fvecs2bitvecs(xt, qcodes, nbits, n);
int* idistances = new int[n * k];
ScopeDeleter<int> del3(idistances);
int_maxheap_array_t res = {size_t(n), size_t(k), labels, idistances};
hammings_knn_hc(&res, qcodes, codes.data(), ntotal, code_size, true);
// convert distances to floats
for (int i = 0; i < k * n; i++)
distances[i] = idistances[i];
}
void IndexLSH::transfer_thresholds(LinearTransform* vt) {
if (!train_thresholds)
return;
FAISS_THROW_IF_NOT(nbits == vt->d_out);
if (!vt->have_bias) {
vt->b.resize(nbits, 0);
vt->have_bias = true;
}
for (int i = 0; i < nbits; i++)
vt->b[i] -= thresholds[i];
train_thresholds = false;
thresholds.clear();
}
void IndexLSH::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_preprocess(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
fvecs2bitvecs(xt, bytes, nbits, n);
}
void IndexLSH::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
float* xt = x;
ScopeDeleter<float> del;
if (rotate_data || nbits != d) {
xt = new float[n * nbits];
del.set(xt);
}
bitvecs2fvecs(bytes, xt, nbits, n);
if (train_thresholds) {
float* xp = xt;
for (idx_t i = 0; i < n; i++) {
for (int j = 0; j < nbits; j++) {
*xp++ += thresholds[j];
}
}
}
if (rotate_data) {
rrot.reverse_transform(n, xt, x);
} else if (nbits != d) {
for (idx_t i = 0; i < n; i++) {
memcpy(x + i * d, xt + i * nbits, nbits * sizeof(xt[0]));
}
}
}
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#ifndef INDEX_LSH_H
#define INDEX_LSH_H
#include <vector>
#include <faiss/IndexFlatCodes.h>
#include <faiss/VectorTransform.h>
namespace faiss {
/** The sign of each vector component is put in a binary signature */
struct IndexLSH : IndexFlatCodes {
int nbits; ///< nb of bits per vector
bool rotate_data; ///< whether to apply a random rotation to input
bool train_thresholds; ///< whether we train thresholds or use 0
RandomRotationMatrix rrot; ///< optional random rotation
std::vector<float> thresholds; ///< thresholds to compare with
IndexLSH(
idx_t d,
int nbits,
bool rotate_data = true,
bool train_thresholds = false);
/** Preprocesses and resizes the input to the size required to
* binarize the data
*
* @param x input vectors, size n * d
* @return output vectors, size n * bits. May be the same pointer
* as x, otherwise it should be deleted by the caller
*/
const float* apply_preprocess(idx_t n, const float* x) const;
void train(idx_t n, const float* x) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
/// transfer the thresholds to a pre-processing stage (and unset
/// train_thresholds)
void transfer_thresholds(LinearTransform* vt);
~IndexLSH() override {}
IndexLSH();
/* standalone codec interface.
*
* The vectors are decoded to +/- 1 (not 0, 1) */
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
};
} // namespace faiss
#endif
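// A minimal usage sketch for IndexLSH (illustrative: d, nbits, k and the
// constructor flags are assumptions; add() and ntotal come from the
// IndexFlatCodes base class rather than from this header).
#include <faiss/IndexLSH.h>
#include <vector>
void lsh_usage_sketch(
        size_t nb, const float* xb,   // database vectors, size nb * d
        size_t nq, const float* xq) { // query vectors, size nq * d
    int d = 64, nbits = 256;
    faiss::IndexLSH index(d, nbits, /*rotate_data=*/true,
                          /*train_thresholds=*/true);
    index.train(nb, xb); // only required because train_thresholds is true
    index.add(nb, xb);   // encodes each vector into an nbits-bit signature
    faiss::Index::idx_t k = 4;
    std::vector<faiss::Index::idx_t> labels(nq * k);
    std::vector<float> distances(nq * k); // Hamming distances, returned as floats
    index.search(nq, xq, k, distances.data(), labels.data());
}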
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexLattice.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/hamming.h> // for the bitstring routines
namespace faiss {
IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2)
: Index(d),
nsq(nsq),
dsq(d / nsq),
zn_sphere_codec(dsq, r2),
scale_nbit(scale_nbit) {
FAISS_THROW_IF_NOT(d % nsq == 0);
lattice_nbit = 0;
while (!(((uint64_t)1 << lattice_nbit) >= zn_sphere_codec.nv)) {
lattice_nbit++;
}
int total_nbit = (lattice_nbit + scale_nbit) * nsq;
code_size = (total_nbit + 7) / 8;
is_trained = false;
}
void IndexLattice::train(idx_t n, const float* x) {
// compute ranges per sub-block
trained.resize(nsq * 2);
float* mins = trained.data();
float* maxs = trained.data() + nsq;
for (int sq = 0; sq < nsq; sq++) {
mins[sq] = HUGE_VAL;
maxs[sq] = -1;
}
for (idx_t i = 0; i < n; i++) {
for (int sq = 0; sq < nsq; sq++) {
float norm2 = fvec_norm_L2sqr(x + i * d + sq * dsq, dsq);
if (norm2 > maxs[sq])
maxs[sq] = norm2;
if (norm2 < mins[sq])
mins[sq] = norm2;
}
}
for (int sq = 0; sq < nsq; sq++) {
mins[sq] = sqrtf(mins[sq]);
maxs[sq] = sqrtf(maxs[sq]);
}
is_trained = true;
}
/* The standalone codec interface */
size_t IndexLattice::sa_code_size() const {
return code_size;
}
void IndexLattice::sa_encode(idx_t n, const float* x, uint8_t* codes) const {
const float* mins = trained.data();
const float* maxs = mins + nsq;
int64_t sc = int64_t(1) << scale_nbit;
#pragma omp parallel for
for (idx_t i = 0; i < n; i++) {
BitstringWriter wr(codes + i * code_size, code_size);
const float* xi = x + i * d;
for (int j = 0; j < nsq; j++) {
float nj = (sqrtf(fvec_norm_L2sqr(xi, dsq)) - mins[j]) * sc /
(maxs[j] - mins[j]);
if (nj < 0)
nj = 0;
if (nj >= sc)
nj = sc - 1;
wr.write((int64_t)nj, scale_nbit);
wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
xi += dsq;
}
}
}
void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
const float* mins = trained.data();
const float* maxs = mins + nsq;
float sc = int64_t(1) << scale_nbit;
float r = sqrtf(zn_sphere_codec.r2);
#pragma omp parallel for
for (idx_t i = 0; i < n; i++) {
BitstringReader rd(codes + i * code_size, code_size);
float* xi = x + i * d;
for (int j = 0; j < nsq; j++) {
float norm =
(rd.read(scale_nbit) + 0.5) * (maxs[j] - mins[j]) / sc +
mins[j];
norm /= r;
zn_sphere_codec.decode(rd.read(lattice_nbit), xi);
for (int l = 0; l < dsq; l++) {
xi[l] *= norm;
}
xi += dsq;
}
}
}
void IndexLattice::add(idx_t, const float*) {
FAISS_THROW_MSG("not implemented");
}
void IndexLattice::search(idx_t, const float*, idx_t, float*, idx_t*) const {
FAISS_THROW_MSG("not implemented");
}
void IndexLattice::reset() {
FAISS_THROW_MSG("not implemented");
}
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#ifndef FAISS_INDEX_LATTICE_H
#define FAISS_INDEX_LATTICE_H
#include <vector>
#include <faiss/IndexIVF.h>
#include <faiss/impl/lattice_Zn.h>
namespace faiss {
/** Index that encodes a vector with a series of Zn lattice quantizers
*/
struct IndexLattice : Index {
/// number of sub-vectors
int nsq;
/// dimension of sub-vectors
size_t dsq;
/// the lattice quantizer
ZnSphereCodecAlt zn_sphere_codec;
/// nb bits used to encode the scale, per subvector
int scale_nbit, lattice_nbit;
/// total, in bytes
size_t code_size;
/// mins and maxes of the vector norms, per subquantizer
std::vector<float> trained;
IndexLattice(idx_t d, int nsq, int scale_nbit, int r2);
void train(idx_t n, const float* x) override;
/* The standalone codec interface */
size_t sa_code_size() const override;
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
/// not implemented
void add(idx_t n, const float* x) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
void reset() override;
};
} // namespace faiss
#endif
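// A minimal codec round-trip sketch for IndexLattice (illustrative: the
// parameters d=32, nsq=4, scale_nbit=4 and r2=10 are assumptions). The index
// is used only through the standalone codec interface, since add(), search()
// and reset() throw "not implemented".
#include <faiss/IndexLattice.h>
#include <cstdint>
#include <vector>
void lattice_codec_sketch(size_t n, const float* x) { // x has size n * 32
    faiss::IndexLattice codec(/*d=*/32, /*nsq=*/4, /*scale_nbit=*/4, /*r2=*/10);
    codec.train(n, x); // estimates the per-subvector norm ranges (mins/maxs)
    std::vector<uint8_t> codes(n * codec.sa_code_size());
    codec.sa_encode(n, x, codes.data());
    std::vector<float> decoded(n * 32);
    codec.sa_decode(n, codes.data(), decoded.data()); // lossy reconstruction
}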
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexNNDescent.h>
#include <omp.h>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <queue>
#include <unordered_set>
#ifdef __SSE__
#endif
#include <faiss/IndexFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/random.h>
extern "C" {
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
int sgemm_(
const char* transa,
const char* transb,
FINTEGER* m,
FINTEGER* n,
FINTEGER* k,
const float* alpha,
const float* a,
FINTEGER* lda,
const float* b,
FINTEGER* ldb,
float* beta,
float* c,
FINTEGER* ldc);
}
namespace faiss {
using idx_t = Index::idx_t;
using storage_idx_t = NNDescent::storage_idx_t;
/**************************************************************
* add / search blocks of descriptors
**************************************************************/
namespace {
/* Wrap the distance computer into one that negates the
distances. This makes supporting INNER_PRODUCT search easier */
struct NegativeDistanceComputer : DistanceComputer {
/// owned by this
DistanceComputer* basedis;
explicit NegativeDistanceComputer(DistanceComputer* basedis)
: basedis(basedis) {}
void set_query(const float* x) override {
basedis->set_query(x);
}
/// compute distance of vector i to current query
float operator()(idx_t i) override {
return -(*basedis)(i);
}
/// compute distance between two stored vectors
float symmetric_dis(idx_t i, idx_t j) override {
return -basedis->symmetric_dis(i, j);
}
~NegativeDistanceComputer() override {
delete basedis;
}
};
DistanceComputer* storage_distance_computer(const Index* storage) {
if (storage->metric_type == METRIC_INNER_PRODUCT) {
return new NegativeDistanceComputer(storage->get_distance_computer());
} else {
return storage->get_distance_computer();
}
}
} // namespace
/**************************************************************
* IndexNNDescent implementation
**************************************************************/
IndexNNDescent::IndexNNDescent(int d, int K, MetricType metric)
: Index(d, metric),
nndescent(d, K),
own_fields(false),
storage(nullptr) {}
IndexNNDescent::IndexNNDescent(Index* storage, int K)
: Index(storage->d, storage->metric_type),
nndescent(storage->d, K),
own_fields(false),
storage(storage) {}
IndexNNDescent::~IndexNNDescent() {
if (own_fields) {
delete storage;
}
}
void IndexNNDescent::train(idx_t n, const float* x) {
FAISS_THROW_IF_NOT_MSG(
storage,
"Please use IndexNNDescentFlat (or variants) "
"instead of IndexNNDescent directly");
// nndescent structure does not require training
storage->train(n, x);
is_trained = true;
}
void IndexNNDescent::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const
{
FAISS_THROW_IF_NOT_MSG(
storage,
"Please use IndexNNDescentFlat (or variants) "
"instead of IndexNNDescent directly");
if (verbose) {
printf("Parameters: k=%" PRId64 ", search_L=%d\n",
k,
nndescent.search_L);
}
idx_t check_period =
InterruptCallback::get_period_hint(d * nndescent.search_L);
for (idx_t i0 = 0; i0 < n; i0 += check_period) {
idx_t i1 = std::min(i0 + check_period, n);
#pragma omp parallel
{
VisitedTable vt(ntotal);
DistanceComputer* dis = storage_distance_computer(storage);
ScopeDeleter1<DistanceComputer> del(dis);
#pragma omp for
for (idx_t i = i0; i < i1; i++) {
idx_t* idxi = labels + i * k;
float* simi = distances + i * k;
dis->set_query(x + i * d);
nndescent.search(*dis, k, idxi, simi, vt);
}
}
InterruptCallback::check();
}
if (metric_type == METRIC_INNER_PRODUCT) {
// we need to revert the negated distances
for (size_t i = 0; i < k * n; i++) {
distances[i] = -distances[i];
}
}
}
void IndexNNDescent::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT_MSG(
storage,
"Please use IndexNNDescentFlat (or variants) "
"instead of IndexNNDescent directly");
FAISS_THROW_IF_NOT(is_trained);
if (ntotal != 0) {
fprintf(stderr,
"WARNING NNDescent doest not support dynamic insertions,"
"multiple insertions would lead to re-building the index");
}
storage->add(n, x);
ntotal = storage->ntotal;
DistanceComputer* dis = storage_distance_computer(storage);
ScopeDeleter1<DistanceComputer> del(dis);
nndescent.build(*dis, ntotal, verbose);
}
void IndexNNDescent::reset() {
nndescent.reset();
storage->reset();
ntotal = 0;
}
void IndexNNDescent::reconstruct(idx_t key, float* recons) const {
storage->reconstruct(key, recons);
}
/**************************************************************
* IndexNNDescentFlat implementation
**************************************************************/
IndexNNDescentFlat::IndexNNDescentFlat() {
is_trained = true;
}
IndexNNDescentFlat::IndexNNDescentFlat(int d, int M, MetricType metric)
: IndexNNDescent(new IndexFlat(d, metric), M) {
own_fields = true;
is_trained = true;
}
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#pragma once
#include <vector>
#include <faiss/IndexFlat.h>
#include <faiss/impl/NNDescent.h>
#include <faiss/utils/utils.h>
namespace faiss {
/** The NNDescent index is a normal random-access index with an NNDescent
* link structure built on top */
struct IndexNNDescent : Index {
// internal storage of vectors (32 bits)
using storage_idx_t = NNDescent::storage_idx_t;
/// Faiss results are 64-bit
using idx_t = Index::idx_t;
// the link structure
NNDescent nndescent;
// the sequential storage
bool own_fields;
Index* storage;
explicit IndexNNDescent(
int d = 0,
int K = 32,
MetricType metric = METRIC_L2);
explicit IndexNNDescent(Index* storage, int K = 32);
~IndexNNDescent() override;
void add(idx_t n, const float* x) override;
/// Trains the storage if needed
void train(idx_t n, const float* x) override;
/// entry point for search
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
void reconstruct(idx_t key, float* recons) const override;
void reset() override;
};
/** Flat index topped with an NNDescent structure to access elements
* more efficiently.
*/
struct IndexNNDescentFlat : IndexNNDescent {
IndexNNDescentFlat();
IndexNNDescentFlat(int d, int K, MetricType metric = METRIC_L2);
};
} // namespace faiss
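// A minimal usage sketch for IndexNNDescentFlat (illustrative: d, K, k and
// search_L are assumptions; search_L is a public field of the NNDescent
// structure, as used by IndexNNDescent::search above). The Flat variant owns
// its storage, so only add() and search() are needed.
#include <faiss/IndexNNDescent.h>
#include <vector>
void nndescent_usage_sketch(
        size_t nb, const float* xb, size_t nq, const float* xq) {
    int d = 64, K = 32; // K = graph degree used while building
    faiss::IndexNNDescentFlat index(d, K, faiss::METRIC_L2);
    index.add(nb, xb);  // builds the NNDescent graph over all vectors at once
    index.nndescent.search_L = 50; // larger search_L: more accurate, slower
    faiss::Index::idx_t k = 10;
    std::vector<faiss::Index::idx_t> labels(nq * k);
    std::vector<float> distances(nq * k);
    index.search(nq, xq, k, distances.data(), labels.data());
}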
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexNSG.h>
#include <omp.h>
#include <cinttypes>
#include <memory>
#include <faiss/IndexFlat.h>
#include <faiss/IndexNNDescent.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
namespace faiss {
using idx_t = Index::idx_t;
using namespace nsg;
/**************************************************************
* IndexNSG implementation
**************************************************************/
IndexNSG::IndexNSG(int d, int R, MetricType metric)
: Index(d, metric),
nsg(R),
own_fields(false),
storage(nullptr),
is_built(false),
GK(64),
build_type(0) {
nndescent_S = 10;
nndescent_R = 100;
nndescent_L = GK + 50;
nndescent_iter = 10;
}
IndexNSG::IndexNSG(Index* storage, int R)
: Index(storage->d, storage->metric_type),
nsg(R),
own_fields(false),
storage(storage),
is_built(false),
GK(64),
build_type(1) {
nndescent_S = 10;
nndescent_R = 100;
nndescent_L = GK + 50;
nndescent_iter = 10;
}
IndexNSG::~IndexNSG() {
if (own_fields) {
delete storage;
}
}
void IndexNSG::train(idx_t n, const float* x) {
FAISS_THROW_IF_NOT_MSG(
storage,
"Please use IndexNSGFlat (or variants) instead of IndexNSG directly");
// nsg structure does not require training
storage->train(n, x);
is_trained = true;
}
void IndexNSG::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const
{
FAISS_THROW_IF_NOT_MSG(
storage,
"Please use IndexNSGFlat (or variants) instead of IndexNSG directly");
int L = std::max(nsg.search_L, (int)k); // in case search_L == -1
idx_t check_period = InterruptCallback::get_period_hint(d * L);
for (idx_t i0 = 0; i0 < n; i0 += check_period) {
idx_t i1 = std::min(i0 + check_period, n);
#pragma omp parallel
{
VisitedTable vt(ntotal);
DistanceComputer* dis = storage_distance_computer(storage);
ScopeDeleter1<DistanceComputer> del(dis);
#pragma omp for
for (idx_t i = i0; i < i1; i++) {
idx_t* idxi = labels + i * k;
float* simi = distances + i * k;
dis->set_query(x + i * d);
nsg.search(*dis, k, idxi, simi, vt);
vt.advance();
}
}
InterruptCallback::check();
}
if (metric_type == METRIC_INNER_PRODUCT) {
// we need to revert the negated distances
for (size_t i = 0; i < k * n; i++) {
distances[i] = -distances[i];
}
}
}
void IndexNSG::build(idx_t n, const float* x, idx_t* knn_graph, int GK) {
FAISS_THROW_IF_NOT_MSG(
storage,
"Please use IndexNSGFlat (or variants) instead of IndexNSG directly");
FAISS_THROW_IF_NOT_MSG(
!is_built && ntotal == 0, "The IndexNSG is already built");
storage->add(n, x);
ntotal = storage->ntotal;
// check the knn graph
check_knn_graph(knn_graph, n, GK);
const nsg::Graph<idx_t> knng(knn_graph, n, GK);
nsg.build(storage, n, knng, verbose);
is_built = true;
}
void IndexNSG::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT_MSG(
storage,
"Please use IndexNSGFlat (or variants) "
"instead of IndexNSG directly");
FAISS_THROW_IF_NOT(is_trained);
FAISS_THROW_IF_NOT_MSG(
!is_built && ntotal == 0,
"NSG does not support incremental addition");
std::vector<idx_t> knng;
if (verbose) {
printf("IndexNSG::add %zd vectors\n", size_t(n));
}
if (build_type == 0) { // build with brute force search
if (verbose) {
printf(" Build knn graph with brute force search on storage index\n");
}
storage->add(n, x);
ntotal = storage->ntotal;
FAISS_THROW_IF_NOT(ntotal == n);
knng.resize(ntotal * (GK + 1));
storage->assign(ntotal, x, knng.data(), GK + 1);
// Remove the self-match from the knn graph
// - For a metric distance, we just need to remove the first neighbor
// - But for a non-metric one, e.g. inner product, we need to check
//   each neighbor
if (storage->metric_type == METRIC_INNER_PRODUCT) {
for (idx_t i = 0; i < ntotal; i++) {
int count = 0;
for (int j = 0; j < GK + 1; j++) {
idx_t id = knng[i * (GK + 1) + j];
if (id != i) {
knng[i * GK + count] = id;
count += 1;
}
if (count == GK) {
break;
}
}
}
} else {
for (idx_t i = 0; i < ntotal; i++) {
memmove(knng.data() + i * GK,
knng.data() + i * (GK + 1) + 1,
GK * sizeof(idx_t));
}
}
} else if (build_type == 1) { // build with NNDescent
IndexNNDescent index(storage, GK);
index.nndescent.S = nndescent_S;
index.nndescent.R = nndescent_R;
index.nndescent.L = std::max(nndescent_L, GK + 50);
index.nndescent.iter = nndescent_iter;
index.verbose = verbose;
if (verbose) {
printf(" Build knn graph with NNdescent S=%d R=%d L=%d niter=%d\n",
index.nndescent.S,
index.nndescent.R,
index.nndescent.L,
index.nndescent.iter);
}
// prevent the temporary IndexNNDescent from deleting the storage
index.own_fields = false;
index.add(n, x);
// storage->add was already called implicitly by index.add above
ntotal = storage->ntotal;
FAISS_THROW_IF_NOT(ntotal == n);
knng.resize(ntotal * GK);
// cast from idx_t to int
const int* knn_graph = index.nndescent.final_graph.data();
#pragma omp parallel for
for (idx_t i = 0; i < ntotal * GK; i++) {
knng[i] = knn_graph[i];
}
} else {
FAISS_THROW_MSG("build_type should be 0 or 1");
}
if (verbose) {
printf(" Check the knn graph\n");
}
// check the knn graph
check_knn_graph(knng.data(), n, GK);
if (verbose) {
printf(" nsg building\n");
}
const nsg::Graph<idx_t> knn_graph(knng.data(), n, GK);
nsg.build(storage, n, knn_graph, verbose);
is_built = true;
}
void IndexNSG::reset() {
nsg.reset();
storage->reset();
ntotal = 0;
is_built = false;
}
void IndexNSG::reconstruct(idx_t key, float* recons) const {
storage->reconstruct(key, recons);
}
void IndexNSG::check_knn_graph(const idx_t* knn_graph, idx_t n, int K) const {
idx_t total_count = 0;
#pragma omp parallel for reduction(+ : total_count)
for (idx_t i = 0; i < n; i++) {
int count = 0;
for (int j = 0; j < K; j++) {
idx_t id = knn_graph[i * K + j];
if (id < 0 || id >= n || id == i) {
count += 1;
}
}
total_count += count;
}
if (total_count > 0) {
fprintf(stderr,
"WARNING: the input knn graph "
"has %" PRId64 " invalid entries\n",
total_count);
}
FAISS_THROW_IF_NOT_MSG(
total_count < n / 10,
"There are too much invalid entries in the knn graph. "
"It may be an invalid knn graph.");
}
/**************************************************************
* IndexNSGFlat implementation
**************************************************************/
IndexNSGFlat::IndexNSGFlat() {
is_trained = true;
}
IndexNSGFlat::IndexNSGFlat(int d, int R, MetricType metric)
: IndexNSG(new IndexFlat(d, metric), R) {
own_fields = true;
is_trained = true;
}
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#pragma once
#include <vector>
#include <faiss/IndexFlat.h>
#include <faiss/IndexNNDescent.h>
#include <faiss/impl/NSG.h>
#include <faiss/utils/utils.h>
namespace faiss {
/** The NSG index is a normal random-access index with an NSG
* link structure built on top */
struct IndexNSG : Index {
/// the link structure
NSG nsg;
/// the sequential storage
bool own_fields;
Index* storage;
/// the index is built or not
bool is_built;
/// K of KNN graph for building
int GK;
/// indicate how to build a knn graph
/// - 0: build NSG with brute force search
/// - 1: build NSG with NNDescent
char build_type;
/// parameters for nndescent
int nndescent_S;
int nndescent_R;
int nndescent_L;
int nndescent_iter;
explicit IndexNSG(int d = 0, int R = 32, MetricType metric = METRIC_L2);
explicit IndexNSG(Index* storage, int R = 32);
~IndexNSG() override;
void build(idx_t n, const float* x, idx_t* knn_graph, int GK);
void add(idx_t n, const float* x) override;
/// Trains the storage if needed
void train(idx_t n, const float* x) override;
/// entry point for search
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
void reconstruct(idx_t key, float* recons) const override;
void reset() override;
void check_knn_graph(const idx_t* knn_graph, idx_t n, int K) const;
};
/** Flat index topped with an NSG structure to access elements
* more efficiently.
*/
struct IndexNSGFlat : IndexNSG {
IndexNSGFlat();
IndexNSGFlat(int d, int R, MetricType metric = METRIC_L2);
};
} // namespace faiss
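// A minimal usage sketch for IndexNSGFlat (illustrative: d, R, GK, k and
// build_type are assumptions). build_type selects how the initial knn graph
// is produced, as documented above: 0 = brute-force search, 1 = NNDescent.
#include <faiss/IndexNSG.h>
#include <vector>
void nsg_usage_sketch(
        size_t nb, const float* xb, size_t nq, const float* xq) {
    int d = 64, R = 32;
    faiss::IndexNSGFlat index(d, R, faiss::METRIC_L2);
    index.GK = 64;        // K of the knn graph built before the NSG step
    index.build_type = 1; // build the knn graph with NNDescent
    index.add(nb, xb);    // NSG requires all vectors to be added in one call
    faiss::Index::idx_t k = 10;
    std::vector<faiss::Index::idx_t> labels(nq * k);
    std::vector<float> distances(nq * k);
    index.search(nq, xq, k, distances.data(), labels.data());
}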
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexPQ.h>
#include <cinttypes>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/hamming.h>
namespace faiss {
/*********************************************************
* IndexPQ implementation
********************************************************/
IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric)
: IndexFlatCodes(0, d, metric), pq(d, M, nbits) {
is_trained = false;
do_polysemous_training = false;
polysemous_ht = nbits * M + 1;
search_type = ST_PQ;
encode_signs = false;
code_size = pq.code_size;
}
IndexPQ::IndexPQ() {
metric_type = METRIC_L2;
is_trained = false;
do_polysemous_training = false;
polysemous_ht = pq.nbits * pq.M + 1;
search_type = ST_PQ;
encode_signs = false;
}
void IndexPQ::train(idx_t n, const float* x) {
if (!do_polysemous_training) { // standard training
pq.train(n, x);
} else {
idx_t ntrain_perm = polysemous_training.ntrain_permutation;
if (ntrain_perm > n / 4)
ntrain_perm = n / 4;
if (verbose) {
printf("PQ training on %" PRId64 " points, remains %" PRId64
" points: "
"training polysemous on %s\n",
n - ntrain_perm,
ntrain_perm,
ntrain_perm == 0 ? "centroids" : "these");
}
pq.train(n - ntrain_perm, x);
polysemous_training.optimize_pq_for_hamming(
pq, ntrain_perm, x + (n - ntrain_perm) * d);
}
is_trained = true;
}
namespace {
template <class PQDecoder>
struct PQDistanceComputer : DistanceComputer {
size_t d;
MetricType metric;
Index::idx_t nb;
const uint8_t* codes;
size_t code_size;
const ProductQuantizer& pq;
const float* sdc;
std::vector<float> precomputed_table;
size_t ndis;
float operator()(idx_t i) override {
const uint8_t* code = codes + i * code_size;
const float* dt = precomputed_table.data();
PQDecoder decoder(code, pq.nbits);
float accu = 0;
for (int j = 0; j < pq.M; j++) {
accu += dt[decoder.decode()];
dt += 1 << decoder.nbits;
}
ndis++;
return accu;
}
float symmetric_dis(idx_t i, idx_t j) override {
FAISS_THROW_IF_NOT(sdc);
const float* sdci = sdc;
float accu = 0;
PQDecoder codei(codes + i * code_size, pq.nbits);
PQDecoder codej(codes + j * code_size, pq.nbits);
for (int l = 0; l < pq.M; l++) {
accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
sdci += uint64_t(1) << (2 * codei.nbits);
}
ndis++;
return accu;
}
explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
precomputed_table.resize(pq.M * pq.ksub);
nb = storage.ntotal;
d = storage.d;
metric = storage.metric_type;
codes = storage.codes.data();
code_size = pq.code_size;
if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
sdc = pq.sdc_table.data();
} else {
sdc = nullptr;
}
ndis = 0;
}
void set_query(const float* x) override {
if (metric == METRIC_L2) {
pq.compute_distance_table(x, precomputed_table.data());
} else {
pq.compute_inner_prod_table(x, precomputed_table.data());
}
}
};
} // namespace
DistanceComputer* IndexPQ::get_distance_computer() const {
if (pq.nbits == 8) {
return new PQDistanceComputer<PQDecoder8>(*this);
} else if (pq.nbits == 16) {
return new PQDistanceComputer<PQDecoder16>(*this);
} else {
return new PQDistanceComputer<PQDecoderGeneric>(*this);
}
}
/*****************************************
* IndexPQ polysemous search routines
******************************************/
void IndexPQ::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);
if (search_type == ST_PQ) { // Simple PQ search
if (metric_type == METRIC_L2) {
float_maxheap_array_t res = {
size_t(n), size_t(k), labels, distances};
pq.search(x, n, codes.data(), ntotal, &res, true);
} else {
float_minheap_array_t res = {
size_t(n), size_t(k), labels, distances};
pq.search_ip(x, n, codes.data(), ntotal, &res, true);
}
indexPQ_stats.nq += n;
indexPQ_stats.ncode += n * ntotal;
} else if (
search_type == ST_polysemous ||
search_type == ST_polysemous_generalize) {
FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
search_core_polysemous(n, x, k, distances, labels);
} else { // code-to-code distances
uint8_t* q_codes = new uint8_t[n * pq.code_size];
ScopeDeleter<uint8_t> del(q_codes);
if (!encode_signs) {
pq.compute_codes(x, q_codes, n);
} else {
FAISS_THROW_IF_NOT(d == pq.nbits * pq.M);
memset(q_codes, 0, n * pq.code_size);
for (size_t i = 0; i < n; i++) {
const float* xi = x + i * d;
uint8_t* code = q_codes + i * pq.code_size;
for (int j = 0; j < d; j++)
if (xi[j] > 0)
code[j >> 3] |= 1 << (j & 7);
}
}
if (search_type == ST_SDC) {
float_maxheap_array_t res = {
size_t(n), size_t(k), labels, distances};
pq.search_sdc(q_codes, n, codes.data(), ntotal, &res, true);
} else {
int* idistances = new int[n * k];
ScopeDeleter<int> del(idistances);
int_maxheap_array_t res = {
size_t(n), size_t(k), labels, idistances};
if (search_type == ST_HE) {
hammings_knn_hc(
&res,
q_codes,
codes.data(),
ntotal,
pq.code_size,
true);
} else if (search_type == ST_generalized_HE) {
generalized_hammings_knn_hc(
&res,
q_codes,
codes.data(),
ntotal,
pq.code_size,
true);
}
// convert distances to floats
for (int i = 0; i < k * n; i++)
distances[i] = idistances[i];
}
indexPQ_stats.nq += n;
indexPQ_stats.ncode += n * ntotal;
}
}
void IndexPQStats::reset() {
nq = ncode = n_hamming_pass = 0;
}
IndexPQStats indexPQ_stats;
template <class HammingComputer>
static size_t polysemous_inner_loop(
const IndexPQ& index,
const float* dis_table_qi,
const uint8_t* q_code,
size_t k,
float* heap_dis,
int64_t* heap_ids) {
int M = index.pq.M;
int code_size = index.pq.code_size;
int ksub = index.pq.ksub;
size_t ntotal = index.ntotal;
int ht = index.polysemous_ht;
const uint8_t* b_code = index.codes.data();
size_t n_pass_i = 0;
HammingComputer hc(q_code, code_size);
for (int64_t bi = 0; bi < ntotal; bi++) {
int hd = hc.hamming(b_code);
if (hd < ht) {
n_pass_i++;
float dis = 0;
const float* dis_table = dis_table_qi;
for (int m = 0; m < M; m++) {
dis += dis_table[b_code[m]];
dis_table += ksub;
}
if (dis < heap_dis[0]) {
maxheap_replace_top(k, heap_dis, heap_ids, dis, bi);
}
}
b_code += code_size;
}
return n_pass_i;
}
void IndexPQ::search_core_polysemous(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(pq.nbits == 8);
// PQ distance tables
float* dis_tables = new float[n * pq.ksub * pq.M];
ScopeDeleter<float> del(dis_tables);
pq.compute_distance_tables(n, x, dis_tables);
// Hamming embedding queries
uint8_t* q_codes = new uint8_t[n * pq.code_size];
ScopeDeleter<uint8_t> del2(q_codes);
if (false) {
pq.compute_codes(x, q_codes, n);
} else {
#pragma omp parallel for
for (idx_t qi = 0; qi < n; qi++) {
pq.compute_code_from_distance_table(
dis_tables + qi * pq.M * pq.ksub,
q_codes + qi * pq.code_size);
}
}
size_t n_pass = 0;
#pragma omp parallel for reduction(+ : n_pass)
for (idx_t qi = 0; qi < n; qi++) {
const uint8_t* q_code = q_codes + qi * pq.code_size;
const float* dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
int64_t* heap_ids = labels + qi * k;
float* heap_dis = distances + qi * k;
maxheap_heapify(k, heap_dis, heap_ids);
if (search_type == ST_polysemous) {
switch (pq.code_size) {
case 4:
n_pass += polysemous_inner_loop<HammingComputer4>(
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
break;
case 8:
n_pass += polysemous_inner_loop<HammingComputer8>(
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
break;
case 16:
n_pass += polysemous_inner_loop<HammingComputer16>(
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
break;
case 32:
n_pass += polysemous_inner_loop<HammingComputer32>(
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
break;
case 20:
n_pass += polysemous_inner_loop<HammingComputer20>(
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
break;
default:
if (pq.code_size % 4 == 0) {
n_pass += polysemous_inner_loop<HammingComputerDefault>(
*this,
dis_table_qi,
q_code,
k,
heap_dis,
heap_ids);
} else {
FAISS_THROW_FMT(
"code size %zd not supported for polysemous",
pq.code_size);
}
break;
}
} else {
switch (pq.code_size) {
case 8:
n_pass += polysemous_inner_loop<GenHammingComputer8>(
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
break;
case 16:
n_pass += polysemous_inner_loop<GenHammingComputer16>(
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
break;
case 32:
n_pass += polysemous_inner_loop<GenHammingComputer32>(
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
break;
default:
if (pq.code_size % 8 == 0) {
n_pass += polysemous_inner_loop<GenHammingComputerM8>(
*this,
dis_table_qi,
q_code,
k,
heap_dis,
heap_ids);
} else {
FAISS_THROW_FMT(
"code size %zd not supported for polysemous",
pq.code_size);
}
break;
}
}
maxheap_reorder(k, heap_dis, heap_ids);
}
indexPQ_stats.nq += n;
indexPQ_stats.ncode += n * ntotal;
indexPQ_stats.n_hamming_pass += n_pass;
}
/* The standalone codec interface (just remaps to the PQ functions) */
void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
pq.compute_codes(x, bytes, n);
}
void IndexPQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
pq.decode(bytes, x, n);
}
/*****************************************
* Stats of IndexPQ codes
******************************************/
void IndexPQ::hamming_distance_table(idx_t n, const float* x, int32_t* dis)
const {
uint8_t* q_codes = new uint8_t[n * pq.code_size];
ScopeDeleter<uint8_t> del(q_codes);
pq.compute_codes(x, q_codes, n);
hammings(q_codes, codes.data(), n, ntotal, pq.code_size, dis);
}
void IndexPQ::hamming_distance_histogram(
idx_t n,
const float* x,
idx_t nb,
const float* xb,
int64_t* hist) {
FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
FAISS_THROW_IF_NOT(pq.code_size % 8 == 0);
FAISS_THROW_IF_NOT(pq.nbits == 8);
// Hamming embedding queries
uint8_t* q_codes = new uint8_t[n * pq.code_size];
ScopeDeleter<uint8_t> del(q_codes);
pq.compute_codes(x, q_codes, n);
uint8_t* b_codes;
ScopeDeleter<uint8_t> del_b_codes;
if (xb) {
b_codes = new uint8_t[nb * pq.code_size];
del_b_codes.set(b_codes);
pq.compute_codes(xb, b_codes, nb);
} else {
nb = ntotal;
b_codes = codes.data();
}
int nbits = pq.M * pq.nbits;
memset(hist, 0, sizeof(*hist) * (nbits + 1));
size_t bs = 256;
#pragma omp parallel
{
std::vector<int64_t> histi(nbits + 1);
hamdis_t* distances = new hamdis_t[nb * bs];
ScopeDeleter<hamdis_t> del(distances);
#pragma omp for
for (idx_t q0 = 0; q0 < n; q0 += bs) {
// printf ("dis stats: %zd/%zd\n", q0, n);
size_t q1 = q0 + bs;
if (q1 > n)
q1 = n;
hammings(
q_codes + q0 * pq.code_size,
b_codes,
q1 - q0,
nb,
pq.code_size,
distances);
for (size_t i = 0; i < nb * (q1 - q0); i++)
histi[distances[i]]++;
}
#pragma omp critical
{
for (int i = 0; i <= nbits; i++)
hist[i] += histi[i];
}
}
}
/*****************************************
* MultiIndexQuantizer
******************************************/
namespace {
template <typename T>
struct PreSortedArray {
const T* x;
int N;
explicit PreSortedArray(int N) : N(N) {}
void init(const T* x) {
this->x = x;
}
// get smallest value
T get_0() {
return x[0];
}
// get delta between the n-th smallest and the (n-1)-th smallest
T get_diff(int n) {
return x[n] - x[n - 1];
}
// remap orders counted from smallest to indices in array
int get_ord(int n) {
return n;
}
};
template <typename T>
struct ArgSort {
const T* x;
bool operator()(size_t i, size_t j) {
return x[i] < x[j];
}
};
/** Array that maintains a permutation of its elements so that the
* array's elements are sorted
*/
template <typename T>
struct SortedArray {
const T* x;
int N;
std::vector<int> perm;
explicit SortedArray(int N) {
this->N = N;
perm.resize(N);
}
void init(const T* x) {
this->x = x;
for (int n = 0; n < N; n++)
perm[n] = n;
ArgSort<T> cmp = {x};
std::sort(perm.begin(), perm.end(), cmp);
}
// get smallest value
T get_0() {
return x[perm[0]];
}
// get delta between the n-th smallest and the (n-1)-th smallest
T get_diff(int n) {
return x[perm[n]] - x[perm[n - 1]];
}
// remap orders counted from smallest to indices in array
int get_ord(int n) {
return perm[n];
}
};
/** Array has n values. Sort the k first ones and copy the other ones
* into elements k..n-1
*/
template <class C>
void partial_sort(
int k,
int n,
const typename C::T* vals,
typename C::TI* perm) {
// insert first k elts in heap
for (int i = 1; i < k; i++) {
indirect_heap_push<C>(i + 1, vals, perm, perm[i]);
}
// insert next n - k elts in heap
for (int i = k; i < n; i++) {
typename C::TI id = perm[i];
typename C::TI top = perm[0];
if (C::cmp(vals[top], vals[id])) {
indirect_heap_pop<C>(k, vals, perm);
indirect_heap_push<C>(k, vals, perm, id);
perm[i] = top;
} else {
// nothing, elt at i is good where it is.
}
}
// order the k first elements in heap
for (int i = k - 1; i > 0; i--) {
typename C::TI top = perm[0];
indirect_heap_pop<C>(i + 1, vals, perm);
perm[i] = top;
}
}
/** same as SortedArray, but only the k first elements are sorted */
template <typename T>
struct SemiSortedArray {
const T* x;
int N;
// type of the heap: CMax = sort ascending
typedef CMax<T, int> HC;
std::vector<int> perm;
int k; // k elements are sorted
int initial_k, k_factor;
explicit SemiSortedArray(int N) {
this->N = N;
perm.resize(N);
initial_k = 3;
k_factor = 4;
}
void init(const T* x) {
this->x = x;
for (int n = 0; n < N; n++)
perm[n] = n;
k = 0;
grow(initial_k);
}
/// grow the sorted part of the array to size next_k
void grow(int next_k) {
if (next_k < N) {
partial_sort<HC>(next_k - k, N - k, x, &perm[k]);
k = next_k;
} else { // full sort of remainder of array
ArgSort<T> cmp = {x};
std::sort(perm.begin() + k, perm.end(), cmp);
k = N;
}
}
// get smallest value
T get_0() {
return x[perm[0]];
}
// get delta between the n-th smallest and the (n-1)-th smallest
T get_diff(int n) {
if (n >= k) {
// want to keep powers of 2 - 1
int next_k = (k + 1) * k_factor - 1;
grow(next_k);
}
return x[perm[n]] - x[perm[n - 1]];
}
// remap orders counted from smallest to indices in array
int get_ord(int n) {
assert(n < k);
return perm[n];
}
};
/*****************************************
* Find the k smallest sums of M terms, where each term is taken in a
* table x of n values.
*
* A combination of terms is encoded as a scalar 0 <= t < n^M. The
* combination t0 ... t(M-1) that correspond to the sum
*
* sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)]
*
* is encoded as
*
* t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1)
*
* MinSumK is an object rather than a function, so that storage can be
* re-used over several computations with the same sizes. use_seen is
* useful when there may be ties in the x array and it is a concern
* that the same t might occasionally be returned several times.
*
* @param x size M * n, values to add up
* @param k nb of results to retrieve
* @param M nb of terms
* @param n nb of distinct values
* @param sums output, size k, sorted
* @param terms output, size k, with encoding as above
*
******************************************/
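/* Worked example of the encoding above (illustrative): with M = 3 terms and
 * n = 4 values per term (nbit = 2), the combination (t0, t1, t2) = (1, 0, 3)
 * is encoded as t = 1 + 0 * 4 + 3 * 4^2 = 49; decoding recovers the terms by
 * repeatedly masking and shifting nbit bits at a time. */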
template <typename T, class SSA, bool use_seen>
struct MinSumK {
int K; ///< nb of sums to return
int M; ///< nb of elements to sum up
int nbit; ///< nb of bits to encode one entry
int N; ///< nb of possible elements for each of the M terms
/** the heap.
* We use a heap to maintain a queue of sums, with the associated
* terms involved in the sum.
*/
typedef CMin<T, int64_t> HC;
size_t heap_capacity, heap_size;
T* bh_val;
int64_t* bh_ids;
std::vector<SSA> ssx;
// all results get pushed several times. When there are ties, they
// are popped interleaved with others, so it is not easy to
// identify them. Therefore, this bit array just marks elements
// that were seen before.
std::vector<uint8_t> seen;
MinSumK(int K, int M, int nbit, int N) : K(K), M(M), nbit(nbit), N(N) {
heap_capacity = K * M;
assert(N <= (1 << nbit));
// we'll do k steps, each step pushes at most M vals
bh_val = new T[heap_capacity];
bh_ids = new int64_t[heap_capacity];
if (use_seen) {
int64_t n_ids = weight(M);
seen.resize((n_ids + 7) / 8);
}
for (int m = 0; m < M; m++)
ssx.push_back(SSA(N));
}
int64_t weight(int i) {
return 1 << (i * nbit);
}
bool is_seen(int64_t i) {
return (seen[i >> 3] >> (i & 7)) & 1;
}
void mark_seen(int64_t i) {
if (use_seen)
seen[i >> 3] |= 1 << (i & 7);
}
void run(const T* x, int64_t ldx, T* sums, int64_t* terms) {
heap_size = 0;
for (int m = 0; m < M; m++) {
ssx[m].init(x);
x += ldx;
}
{ // initial result: take min for all elements
T sum = 0;
terms[0] = 0;
mark_seen(0);
for (int m = 0; m < M; m++) {
sum += ssx[m].get_0();
}
sums[0] = sum;
for (int m = 0; m < M; m++) {
heap_push<HC>(
++heap_size,
bh_val,
bh_ids,
sum + ssx[m].get_diff(1),
weight(m));
}
}
for (int k = 1; k < K; k++) {
// pop smallest value from heap
if (use_seen) { // skip already seen elements
while (is_seen(bh_ids[0])) {
assert(heap_size > 0);
heap_pop<HC>(heap_size--, bh_val, bh_ids);
}
}
assert(heap_size > 0);
T sum = sums[k] = bh_val[0];
int64_t ti = terms[k] = bh_ids[0];
if (use_seen) {
mark_seen(ti);
heap_pop<HC>(heap_size--, bh_val, bh_ids);
} else {
do {
heap_pop<HC>(heap_size--, bh_val, bh_ids);
} while (heap_size > 0 && bh_ids[0] == ti);
}
// enqueue followers
int64_t ii = ti;
for (int m = 0; m < M; m++) {
int64_t n = ii & ((1L << nbit) - 1);
ii >>= nbit;
if (n + 1 >= N)
continue;
enqueue_follower(ti, m, n, sum);
}
}
/*
for (int k = 0; k < K; k++)
for (int l = k + 1; l < K; l++)
assert (terms[k] != terms[l]);
*/
// convert indices by applying permutation
for (int k = 0; k < K; k++) {
int64_t ii = terms[k];
if (use_seen) {
// clear seen for reuse at next loop
seen[ii >> 3] = 0;
}
int64_t ti = 0;
for (int m = 0; m < M; m++) {
int64_t n = ii & ((1L << nbit) - 1);
ti += int64_t(ssx[m].get_ord(n)) << (nbit * m);
ii >>= nbit;
}
terms[k] = ti;
}
}
void enqueue_follower(int64_t ti, int m, int n, T sum) {
T next_sum = sum + ssx[m].get_diff(n + 1);
int64_t next_ti = ti + weight(m);
heap_push<HC>(++heap_size, bh_val, bh_ids, next_sum, next_ti);
}
~MinSumK() {
delete[] bh_ids;
delete[] bh_val;
}
};
} // anonymous namespace
MultiIndexQuantizer::MultiIndexQuantizer(int d, size_t M, size_t nbits)
: Index(d, METRIC_L2), pq(d, M, nbits) {
is_trained = false;
pq.verbose = verbose;
}
void MultiIndexQuantizer::train(idx_t n, const float* x) {
pq.verbose = verbose;
pq.train(n, x);
is_trained = true;
// count virtual elements in index
ntotal = 1;
for (int m = 0; m < pq.M; m++)
ntotal *= pq.ksub;
}
void MultiIndexQuantizer::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
if (n == 0)
return;
FAISS_THROW_IF_NOT(k > 0);
// the allocation just below can be severe...
idx_t bs = 32768;
if (n > bs) {
for (idx_t i0 = 0; i0 < n; i0 += bs) {
idx_t i1 = std::min(i0 + bs, n);
if (verbose) {
printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64
" / %" PRId64 "\n",
i0,
i1,
n);
}
search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
}
return;
}
float* dis_tables = new float[n * pq.ksub * pq.M];
ScopeDeleter<float> del(dis_tables);
pq.compute_distance_tables(n, x, dis_tables);
if (k == 1) {
// simple version that just finds the min in each table
#pragma omp parallel for
for (int i = 0; i < n; i++) {
const float* dis_table = dis_tables + i * pq.ksub * pq.M;
float dis = 0;
idx_t label = 0;
for (int s = 0; s < pq.M; s++) {
float vmin = HUGE_VALF;
idx_t lmin = -1;
for (idx_t j = 0; j < pq.ksub; j++) {
if (dis_table[j] < vmin) {
vmin = dis_table[j];
lmin = j;
}
}
dis += vmin;
label |= lmin << (s * pq.nbits);
dis_table += pq.ksub;
}
distances[i] = dis;
labels[i] = label;
}
} else {
#pragma omp parallel if (n > 1)
{
MinSumK<float, SemiSortedArray<float>, false> msk(
k, pq.M, pq.nbits, pq.ksub);
#pragma omp for
for (int i = 0; i < n; i++) {
msk.run(dis_tables + i * pq.ksub * pq.M,
pq.ksub,
distances + i * k,
labels + i * k);
}
}
}
}
void MultiIndexQuantizer::reconstruct(idx_t key, float* recons) const {
int64_t jj = key;
for (int m = 0; m < pq.M; m++) {
int64_t n = jj & ((1L << pq.nbits) - 1);
jj >>= pq.nbits;
memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
recons += pq.dsub;
}
}
void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
FAISS_THROW_MSG(
"This index has virtual elements, "
"it does not support add");
}
void MultiIndexQuantizer::reset() {
FAISS_THROW_MSG(
"This index has virtual elements, "
"it does not support reset");
}
/*****************************************
* MultiIndexQuantizer2
******************************************/
MultiIndexQuantizer2::MultiIndexQuantizer2(
int d,
size_t M,
size_t nbits,
Index** indexes)
: MultiIndexQuantizer(d, M, nbits) {
assign_indexes.resize(M);
for (int i = 0; i < M; i++) {
FAISS_THROW_IF_NOT_MSG(
indexes[i]->d == pq.dsub,
"Provided sub-index has incorrect size");
assign_indexes[i] = indexes[i];
}
own_fields = false;
}
MultiIndexQuantizer2::MultiIndexQuantizer2(
int d,
size_t nbits,
Index* assign_index_0,
Index* assign_index_1)
: MultiIndexQuantizer(d, 2, nbits) {
FAISS_THROW_IF_NOT_MSG(
assign_index_0->d == pq.dsub && assign_index_1->d == pq.dsub,
"Provided sub-index has incorrect size");
assign_indexes.resize(2);
assign_indexes[0] = assign_index_0;
assign_indexes[1] = assign_index_1;
own_fields = false;
}
void MultiIndexQuantizer2::train(idx_t n, const float* x) {
MultiIndexQuantizer::train(n, x);
// add centroids to sub-indexes
for (int i = 0; i < pq.M; i++) {
assign_indexes[i]->add(pq.ksub, pq.get_centroids(i, 0));
}
}
void MultiIndexQuantizer2::search(
idx_t n,
const float* x,
idx_t K,
float* distances,
idx_t* labels) const {
if (n == 0)
return;
int k2 = std::min(K, int64_t(pq.ksub));
FAISS_THROW_IF_NOT(k2);
int64_t M = pq.M;
int64_t dsub = pq.dsub, ksub = pq.ksub;
// size (M, n, k2)
std::vector<idx_t> sub_ids(n * M * k2);
std::vector<float> sub_dis(n * M * k2);
std::vector<float> xsub(n * dsub);
for (int m = 0; m < M; m++) {
float* xdest = xsub.data();
const float* xsrc = x + m * dsub;
for (int j = 0; j < n; j++) {
memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
xsrc += d;
xdest += dsub;
}
assign_indexes[m]->search(
n, xsub.data(), k2, &sub_dis[k2 * n * m], &sub_ids[k2 * n * m]);
}
if (K == 1) {
// simple version that just finds the min in each table
assert(k2 == 1);
for (int i = 0; i < n; i++) {
float dis = 0;
idx_t label = 0;
for (int m = 0; m < M; m++) {
float vmin = sub_dis[i + m * n];
idx_t lmin = sub_ids[i + m * n];
dis += vmin;
label |= lmin << (m * pq.nbits);
}
distances[i] = dis;
labels[i] = label;
}
} else {
#pragma omp parallel if (n > 1)
{
MinSumK<float, PreSortedArray<float>, false> msk(
K, pq.M, pq.nbits, k2);
#pragma omp for
for (int i = 0; i < n; i++) {
idx_t* li = labels + i * K;
msk.run(&sub_dis[i * k2], k2 * n, distances + i * K, li);
// remap ids
const idx_t* idmap0 = sub_ids.data() + i * k2;
int64_t ld_idmap = k2 * n;
int64_t mask1 = ksub - 1L;
for (int k = 0; k < K; k++) {
const idx_t* idmap = idmap0;
int64_t vin = li[k];
int64_t vout = 0;
int bs = 0;
for (int m = 0; m < M; m++) {
int64_t s = vin & mask1;
vin >>= pq.nbits;
vout |= idmap[s] << bs;
bs += pq.nbits;
idmap += ld_idmap;
}
li[k] = vout;
}
}
}
}
}
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#ifndef FAISS_INDEX_PQ_H
#define FAISS_INDEX_PQ_H
#include <stdint.h>
#include <vector>
#include <faiss/IndexFlatCodes.h>
#include <faiss/impl/PolysemousTraining.h>
#include <faiss/impl/ProductQuantizer.h>
#include <faiss/impl/platform_macros.h>
namespace faiss {
/** Index based on a product quantizer. Stored vectors are
* approximated by PQ codes. */
struct IndexPQ : IndexFlatCodes {
/// The product quantizer used to encode the vectors
ProductQuantizer pq;
/** Constructor.
*
* @param d dimensionality of the input vectors
* @param M number of subquantizers
* @param nbits number of bits per subvector index
*/
IndexPQ(int d, ///< dimensionality of the input vectors
size_t M, ///< number of subquantizers
size_t nbits, ///< number of bits per subvector index
MetricType metric = METRIC_L2);
IndexPQ();
void train(idx_t n, const float* x) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
/* The standalone codec interface */
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
DistanceComputer* get_distance_computer() const override;
/******************************************************
* Polysemous codes implementation
******************************************************/
bool do_polysemous_training; ///< false = standard PQ
/// parameters used for the polysemous training
PolysemousTraining polysemous_training;
/// how to perform the search in search_core
enum Search_type_t {
ST_PQ, ///< asymmetric product quantizer (default)
ST_HE, ///< Hamming distance on codes
ST_generalized_HE, ///< nb of same codes
ST_SDC, ///< symmetric product quantizer (SDC)
ST_polysemous, ///< HE filter (using ht) + PQ combination
ST_polysemous_generalize, ///< Filter on generalized Hamming
};
Search_type_t search_type;
// just encode the sign of the components, instead of using the PQ encoder
// used only for the queries
bool encode_signs;
/// Hamming threshold used for polysemy
int polysemous_ht;
// actual polysemous search
void search_core_polysemous(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const;
/// prepare query for a polysemous search, but instead of
/// computing the result, just get the histogram of Hamming
/// distances. May be computed on a provided dataset if xb != NULL
/// @param dist_histogram (M * nbits + 1)
void hamming_distance_histogram(
idx_t n,
const float* x,
idx_t nb,
const float* xb,
int64_t* dist_histogram);
/** compute pairwise distances between queries and database
*
* @param n nb of query vectors
* @param x query vector, size n * d
* @param dis output distances, size n * ntotal
*/
void hamming_distance_table(idx_t n, const float* x, int32_t* dis) const;
};
/// statistics are robust to internal threading, but not if
/// IndexPQ::search is called by multiple threads
struct IndexPQStats {
size_t nq; // nb of queries run
size_t ncode; // nb of codes visited
size_t n_hamming_pass; // nb of passed Hamming distance tests (for polysemy)
IndexPQStats() {
reset();
}
void reset();
};
FAISS_API extern IndexPQStats indexPQ_stats;
/** Quantizer where centroids are virtual: they are the Cartesian
* product of sub-centroids. */
struct MultiIndexQuantizer : Index {
ProductQuantizer pq;
MultiIndexQuantizer(
int d, ///< dimension of the input vectors
size_t M, ///< number of subquantizers
size_t nbits); ///< number of bits per subvector index
void train(idx_t n, const float* x) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
/// add and reset will crash at runtime
void add(idx_t n, const float* x) override;
void reset() override;
MultiIndexQuantizer() {}
void reconstruct(idx_t key, float* recons) const override;
};
/** MultiIndexQuantizer where the PQ assignment is performed by sub-indexes
*/
struct MultiIndexQuantizer2 : MultiIndexQuantizer {
/// M Indexes on d / M dimensions
std::vector<Index*> assign_indexes;
bool own_fields;
MultiIndexQuantizer2(int d, size_t M, size_t nbits, Index** indexes);
MultiIndexQuantizer2(
int d,
size_t nbits,
Index* assign_index_0,
Index* assign_index_1);
void train(idx_t n, const float* x) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
};
} // namespace faiss
#endif
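// A minimal usage sketch for IndexPQ with polysemous filtering (illustrative:
// d, M, nbits, k and polysemous_ht are assumptions; see the fields declared
// above for the exact semantics).
#include <faiss/IndexPQ.h>
#include <vector>
void pq_polysemous_sketch(
        size_t nt, const float* xt,   // training vectors, size nt * d
        size_t nb, const float* xb,   // database vectors, size nb * d
        size_t nq, const float* xq) { // query vectors, size nq * d
    int d = 64;
    size_t M = 8, nbits = 8; // 8 x 8 = 64-bit codes
    faiss::IndexPQ index(d, M, nbits, faiss::METRIC_L2);
    index.do_polysemous_training = true; // reorder centroids for Hamming filtering
    index.train(nt, xt);
    index.add(nb, xb);
    index.search_type = faiss::IndexPQ::ST_polysemous;
    index.polysemous_ht = 54; // Hamming threshold: lower = faster, less exhaustive
    faiss::Index::idx_t k = 10;
    std::vector<faiss::Index::idx_t> labels(nq * k);
    std::vector<float> distances(nq * k);
    index.search(nq, xq, k, distances.data(), labels.data());
}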
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/IndexPQFastScan.h>
#include <limits.h>
#include <cassert>
#include <memory>
#include <omp.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/random.h>
#include <faiss/utils/utils.h>
#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/impl/simd_result_handlers.h>
#include <faiss/utils/quantize_lut.h>
namespace faiss {
using namespace simd_result_handlers;
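// Round a up to the nearest multiple of b, e.g. roundup(13, 32) == 32 and
// roundup(64, 32) == 64 (b is assumed to be non-zero).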
inline size_t roundup(size_t a, size_t b) {
return (a + b - 1) / b * b;
}
IndexPQFastScan::IndexPQFastScan(
int d,
size_t M,
size_t nbits,
MetricType metric,
int bbs)
: Index(d, metric),
pq(d, M, nbits),
bbs(bbs),
ntotal2(0),
M2(roundup(M, 2)) {
FAISS_THROW_IF_NOT(nbits == 4);
is_trained = false;
}
IndexPQFastScan::IndexPQFastScan() : bbs(0), ntotal2(0), M2(0) {}
IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs)
: Index(orig.d, orig.metric_type), pq(orig.pq), bbs(bbs) {
FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
ntotal = orig.ntotal;
is_trained = orig.is_trained;
orig_codes = orig.codes.data();
qbs = 0; // means use default
// pack the codes
size_t M = pq.M;
FAISS_THROW_IF_NOT(bbs % 32 == 0);
M2 = roundup(M, 2);
ntotal2 = roundup(ntotal, bbs);
codes.resize(ntotal2 * M2 / 2);
// printf("M=%d M2=%d code_size=%d\n", M, M2, pq.code_size);
pq4_pack_codes(orig.codes.data(), ntotal, M, ntotal2, bbs, M2, codes.get());
}
void IndexPQFastScan::train(idx_t n, const float* x) {
if (is_trained) {
return;
}
pq.train(n, x);
is_trained = true;
}
void IndexPQFastScan::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT(is_trained);
AlignedTable<uint8_t> tmp_codes(n * pq.code_size);
pq.compute_codes(x, tmp_codes.get(), n);
ntotal2 = roundup(ntotal + n, bbs);
size_t new_size = ntotal2 * M2 / 2;
size_t old_size = codes.size();
if (new_size > old_size) {
codes.resize(new_size);
memset(codes.get() + old_size, 0, new_size - old_size);
}
pq4_pack_codes_range(
tmp_codes.get(), pq.M, ntotal, ntotal + n, bbs, M2, codes.get());
ntotal += n;
}
void IndexPQFastScan::reset() {
codes.resize(0);
ntotal = 0;
}
namespace {
// from impl/ProductQuantizer.cpp
template <class C, typename dis_t>
void pq_estimators_from_tables_generic(
const ProductQuantizer& pq,
size_t nbits,
const uint8_t* codes,
size_t ncodes,
const dis_t* dis_table,
size_t k,
typename C::T* heap_dis,
int64_t* heap_ids) {
using accu_t = typename C::T;
const size_t M = pq.M;
const size_t ksub = pq.ksub;
for (size_t j = 0; j < ncodes; ++j) {
PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
accu_t dis = 0;
const dis_t* __restrict dt = dis_table;
for (size_t m = 0; m < M; m++) {
uint64_t c = decoder.decode();
dis += dt[c];
dt += ksub;
}
if (C::cmp(heap_dis[0], dis)) {
heap_pop<C>(k, heap_dis, heap_ids);
heap_push<C>(k, heap_dis, heap_ids, dis, j);
}
}
}
} // anonymous namespace
using namespace quantize_lut;
void IndexPQFastScan::compute_quantized_LUT(
idx_t n,
const float* x,
uint8_t* lut,
float* normalizers) const {
size_t dim12 = pq.ksub * pq.M;
std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
if (metric_type == METRIC_L2) {
pq.compute_distance_tables(n, x, dis_tables.get());
} else {
pq.compute_inner_prod_tables(n, x, dis_tables.get());
}
for (uint64_t i = 0; i < n; i++) {
round_uint8_per_column(
dis_tables.get() + i * dim12,
pq.M,
pq.ksub,
&normalizers[2 * i],
&normalizers[2 * i + 1]);
}
for (uint64_t i = 0; i < n; i++) {
const float* t_in = dis_tables.get() + i * dim12;
uint8_t* t_out = lut + i * M2 * pq.ksub;
for (int j = 0; j < dim12; j++) {
t_out[j] = int(t_in[j]);
}
memset(t_out + dim12, 0, (M2 - pq.M) * pq.ksub);
}
}
/******************************************************************************
* Search driver routine
******************************************************************************/
void IndexPQFastScan::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
FAISS_THROW_IF_NOT(k > 0);
if (metric_type == METRIC_L2) {
search_dispatch_implem<true>(n, x, k, distances, labels);
} else {
search_dispatch_implem<false>(n, x, k, distances, labels);
}
}
template <bool is_max>
void IndexPQFastScan::search_dispatch_implem(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
using Cfloat = typename std::conditional<
is_max,
CMax<float, int64_t>,
CMin<float, int64_t>>::type;
using C = typename std::
conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type;
if (n == 0) {
return;
}
// actual implementation used
int impl = implem;
if (impl == 0) {
if (bbs == 32) {
impl = 12;
} else {
impl = 14;
}
if (k > 20) {
impl++;
}
}
if (implem == 1) {
FAISS_THROW_IF_NOT(orig_codes);
FAISS_THROW_IF_NOT(is_max);
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
pq.search(x, n, orig_codes, ntotal, &res, true);
} else if (implem == 2 || implem == 3 || implem == 4) {
FAISS_THROW_IF_NOT(orig_codes);
size_t dim12 = pq.ksub * pq.M;
std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
if (is_max) {
pq.compute_distance_tables(n, x, dis_tables.get());
} else {
pq.compute_inner_prod_tables(n, x, dis_tables.get());
}
std::vector<float> normalizers(n * 2);
if (implem == 2) {
// default float
} else if (implem == 3 || implem == 4) {
for (uint64_t i = 0; i < n; i++) {
round_uint8_per_column(
dis_tables.get() + i * dim12,
pq.M,
pq.ksub,
&normalizers[2 * i],
&normalizers[2 * i + 1]);
}
}
for (int64_t i = 0; i < n; i++) {
int64_t* heap_ids = labels + i * k;
float* heap_dis = distances + i * k;
heap_heapify<Cfloat>(k, heap_dis, heap_ids);
pq_estimators_from_tables_generic<Cfloat>(
pq,
pq.nbits,
orig_codes,
ntotal,
dis_tables.get() + i * dim12,
k,
heap_dis,
heap_ids);
heap_reorder<Cfloat>(k, heap_dis, heap_ids);
if (implem == 4) {
float a = normalizers[2 * i];
float b = normalizers[2 * i + 1];
for (int j = 0; j < k; j++) {
heap_dis[j] = heap_dis[j] / a + b;
}
}
}
} else if (impl >= 12 && impl <= 15) {
FAISS_THROW_IF_NOT(ntotal < INT_MAX);
int nt = std::min(omp_get_max_threads(), int(n));
if (nt < 2) {
if (impl == 12 || impl == 13) {
search_implem_12<C>(n, x, k, distances, labels, impl);
} else {
search_implem_14<C>(n, x, k, distances, labels, impl);
}
} else {
// explicitly slice over threads
#pragma omp parallel for num_threads(nt)
for (int slice = 0; slice < nt; slice++) {
idx_t i0 = n * slice / nt;
idx_t i1 = n * (slice + 1) / nt;
float* dis_i = distances + i0 * k;
idx_t* lab_i = labels + i0 * k;
if (impl == 12 || impl == 13) {
search_implem_12<C>(
i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
} else {
search_implem_14<C>(
i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
}
}
}
} else {
FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
}
}
template <class C>
void IndexPQFastScan::search_implem_12(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
int impl) const {
FAISS_THROW_IF_NOT(bbs == 32);
// handle qbs2 blocking by recursive call
int64_t qbs2 = this->qbs == 0 ? 11 : pq4_qbs_to_nq(this->qbs);
if (n > qbs2) {
for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
int64_t i1 = std::min(i0 + qbs2, n);
search_implem_12<C>(
i1 - i0,
x + d * i0,
k,
distances + i0 * k,
labels + i0 * k,
impl);
}
return;
}
size_t dim12 = pq.ksub * M2;
AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
std::unique_ptr<float[]> normalizers(new float[2 * n]);
if (skip & 1) {
quantized_dis_tables.clear();
} else {
compute_quantized_LUT(
n, x, quantized_dis_tables.get(), normalizers.get());
}
AlignedTable<uint8_t> LUT(n * dim12);
// block sizes are encoded in qbs, 4 bits at a time
    // caution: this local variable shadows the object field of the same name
int qbs = this->qbs;
if (n != pq4_qbs_to_nq(qbs)) {
qbs = pq4_preferred_qbs(n);
}
int LUT_nq =
pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
FAISS_THROW_IF_NOT(LUT_nq == n);
if (k == 1) {
SingleResultHandler<C> handler(n, ntotal);
if (skip & 4) {
// pass
} else {
handler.disable = bool(skip & 2);
pq4_accumulate_loop_qbs(
qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
}
handler.to_flat_arrays(distances, labels, normalizers.get());
} else if (impl == 12) {
std::vector<uint16_t> tmp_dis(n * k);
std::vector<int32_t> tmp_ids(n * k);
if (skip & 4) {
// skip
} else {
HeapHandler<C> handler(
n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
handler.disable = bool(skip & 2);
pq4_accumulate_loop_qbs(
qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
if (!(skip & 8)) {
handler.to_flat_arrays(distances, labels, normalizers.get());
}
}
} else { // impl == 13
ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
handler.disable = bool(skip & 2);
if (skip & 4) {
// skip
} else {
pq4_accumulate_loop_qbs(
qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
}
if (!(skip & 8)) {
handler.to_flat_arrays(distances, labels, normalizers.get());
}
FastScan_stats.t0 += handler.times[0];
FastScan_stats.t1 += handler.times[1];
FastScan_stats.t2 += handler.times[2];
FastScan_stats.t3 += handler.times[3];
}
}
FastScanStats FastScan_stats;
template <class C>
void IndexPQFastScan::search_implem_14(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
int impl) const {
FAISS_THROW_IF_NOT(bbs % 32 == 0);
int qbs2 = qbs == 0 ? 4 : qbs;
// handle qbs2 blocking by recursive call
if (n > qbs2) {
for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
int64_t i1 = std::min(i0 + qbs2, n);
search_implem_14<C>(
i1 - i0,
x + d * i0,
k,
distances + i0 * k,
labels + i0 * k,
impl);
}
return;
}
size_t dim12 = pq.ksub * M2;
AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
std::unique_ptr<float[]> normalizers(new float[2 * n]);
if (skip & 1) {
quantized_dis_tables.clear();
} else {
compute_quantized_LUT(
n, x, quantized_dis_tables.get(), normalizers.get());
}
AlignedTable<uint8_t> LUT(n * dim12);
pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
if (k == 1) {
SingleResultHandler<C> handler(n, ntotal);
if (skip & 4) {
// pass
} else {
handler.disable = bool(skip & 2);
pq4_accumulate_loop(
n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
}
handler.to_flat_arrays(distances, labels, normalizers.get());
} else if (impl == 14) {
std::vector<uint16_t> tmp_dis(n * k);
std::vector<int32_t> tmp_ids(n * k);
if (skip & 4) {
// skip
} else if (k > 1) {
HeapHandler<C> handler(
n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
handler.disable = bool(skip & 2);
pq4_accumulate_loop(
n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
if (!(skip & 8)) {
handler.to_flat_arrays(distances, labels, normalizers.get());
}
}
} else { // impl == 15
ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
handler.disable = bool(skip & 2);
if (skip & 4) {
// skip
} else {
pq4_accumulate_loop(
n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
}
if (!(skip & 8)) {
handler.to_flat_arrays(distances, labels, normalizers.get());
}
}
}
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/IndexPQ.h>
#include <faiss/impl/ProductQuantizer.h>
#include <faiss/utils/AlignedTable.h>
namespace faiss {
/** Fast scan version of IndexPQ. Works for 4-bit PQ for now.
*
* The codes are not stored sequentially but grouped in blocks of size bbs.
* This makes it possible to compute distances quickly with SIMD instructions.
*
* Implementations:
* 12: blocked loop with internal loop on Q with qbs
* 13: same with reservoir accumulator to store results
* 14: no qbs with heap accumulator
* 15: no qbs with reservoir accumulator
*/
struct IndexPQFastScan : Index {
ProductQuantizer pq;
// implementation to select
int implem = 0;
// skip some parts of the computation (for timing)
int skip = 0;
// size of the kernel
int bbs; // set at build time
int qbs = 0; // query block size 0 = use default
// packed version of the codes
size_t ntotal2;
size_t M2;
AlignedTable<uint8_t> codes;
// this is for testing purposes only (set when initialized by IndexPQ)
const uint8_t* orig_codes = nullptr;
IndexPQFastScan(
int d,
size_t M,
size_t nbits,
MetricType metric = METRIC_L2,
int bbs = 32);
IndexPQFastScan();
/// build from an existing IndexPQ
explicit IndexPQFastScan(const IndexPQ& orig, int bbs = 32);
void train(idx_t n, const float* x) override;
void add(idx_t n, const float* x) override;
void reset() override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
// called by search function
void compute_quantized_LUT(
idx_t n,
const float* x,
uint8_t* lut,
float* normalizers) const;
template <bool is_max>
void search_dispatch_implem(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const;
template <class C>
void search_implem_2(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const;
template <class C>
void search_implem_12(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
int impl) const;
template <class C>
void search_implem_14(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
int impl) const;
};
struct FastScanStats {
uint64_t t0, t1, t2, t3;
FastScanStats() {
reset();
}
void reset() {
memset(this, 0, sizeof(*this));
}
};
FAISS_API extern FastScanStats FastScan_stats;
} // namespace faiss
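// ---------------------------------------------------------------------------
// Illustrative usage sketch for IndexPQFastScan (not part of the original
// sources). The data arrays xb (database) and xq (queries) and the helper
// name example_pq_fastscan are placeholders supplied by the caller; only
// 4-bit PQ (nbits == 4) is accepted by the fast-scan code path, and d must
// be a multiple of M.
#include <faiss/IndexPQFastScan.h>
#include <vector>

inline void example_pq_fastscan(
        int d, size_t nb, const float* xb, size_t nq, const float* xq) {
    size_t M = 8, nbits = 4;                 // 4-bit product quantizer
    faiss::IndexPQFastScan index(d, M, nbits, faiss::METRIC_L2);
    index.train(nb, xb);                     // learn the PQ codebooks
    index.add(nb, xb);                       // codes are packed in blocks of bbs
    faiss::Index::idx_t k = 10;
    std::vector<float> distances(nq * k);
    std::vector<faiss::Index::idx_t> labels(nq * k);
    index.search(nq, xq, k, distances.data(), labels.data());
}
// ---------------------------------------------------------------------------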
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexPreTransform.h>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <memory>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
namespace faiss {
/*********************************************
* IndexPreTransform
*********************************************/
IndexPreTransform::IndexPreTransform() : index(nullptr), own_fields(false) {}
IndexPreTransform::IndexPreTransform(Index* index)
: Index(index->d, index->metric_type), index(index), own_fields(false) {
is_trained = index->is_trained;
ntotal = index->ntotal;
}
IndexPreTransform::IndexPreTransform(VectorTransform* ltrans, Index* index)
: Index(index->d, index->metric_type), index(index), own_fields(false) {
is_trained = index->is_trained;
ntotal = index->ntotal;
prepend_transform(ltrans);
}
void IndexPreTransform::prepend_transform(VectorTransform* ltrans) {
FAISS_THROW_IF_NOT(ltrans->d_out == d);
is_trained = is_trained && ltrans->is_trained;
chain.insert(chain.begin(), ltrans);
d = ltrans->d_in;
}
IndexPreTransform::~IndexPreTransform() {
if (own_fields) {
for (int i = 0; i < chain.size(); i++)
delete chain[i];
delete index;
}
}
void IndexPreTransform::train(idx_t n, const float* x) {
int last_untrained = 0;
if (!index->is_trained) {
last_untrained = chain.size();
} else {
for (int i = chain.size() - 1; i >= 0; i--) {
if (!chain[i]->is_trained) {
last_untrained = i;
break;
}
}
}
const float* prev_x = x;
ScopeDeleter<float> del;
if (verbose) {
printf("IndexPreTransform::train: training chain 0 to %d\n",
last_untrained);
}
for (int i = 0; i <= last_untrained; i++) {
if (i < chain.size()) {
VectorTransform* ltrans = chain[i];
if (!ltrans->is_trained) {
if (verbose) {
printf(" Training chain component %d/%zd\n",
i,
chain.size());
if (OPQMatrix* opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
opqm->verbose = true;
}
}
ltrans->train(n, prev_x);
}
} else {
if (verbose) {
printf(" Training sub-index\n");
}
index->train(n, prev_x);
}
if (i == last_untrained)
break;
if (verbose) {
printf(" Applying transform %d/%zd\n", i, chain.size());
}
float* xt = chain[i]->apply(n, prev_x);
if (prev_x != x)
delete[] prev_x;
prev_x = xt;
del.set(xt);
}
is_trained = true;
}
const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
const float* prev_x = x;
ScopeDeleter<float> del;
for (int i = 0; i < chain.size(); i++) {
float* xt = chain[i]->apply(n, prev_x);
ScopeDeleter<float> del2(xt);
del2.swap(del);
prev_x = xt;
}
del.release();
return prev_x;
}
void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
const {
const float* next_x = xt;
ScopeDeleter<float> del;
for (int i = chain.size() - 1; i >= 0; i--) {
float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
chain[i]->reverse_transform(n, next_x, prev_x);
del2.swap(del);
next_x = prev_x;
}
}
void IndexPreTransform::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
index->add(n, xt);
ntotal = index->ntotal;
}
void IndexPreTransform::add_with_ids(
idx_t n,
const float* x,
const idx_t* xids) {
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
index->add_with_ids(n, xt, xids);
ntotal = index->ntotal;
}
void IndexPreTransform::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
index->search(n, xt, k, distances, labels);
}
void IndexPreTransform::range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result) const {
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
index->range_search(n, xt, radius, result);
}
void IndexPreTransform::reset() {
index->reset();
ntotal = 0;
}
size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
size_t nremove = index->remove_ids(sel);
ntotal = index->ntotal;
return nremove;
}
void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
float* x = chain.empty() ? recons : new float[index->d];
ScopeDeleter<float> del(recons == x ? nullptr : x);
// Initial reconstruction
index->reconstruct(key, x);
// Revert transformations from last to first
reverse_chain(1, x, recons);
}
void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
float* x = chain.empty() ? recons : new float[ni * index->d];
ScopeDeleter<float> del(recons == x ? nullptr : x);
// Initial reconstruction
index->reconstruct_n(i0, ni, x);
// Revert transformations from last to first
reverse_chain(ni, x, recons);
}
void IndexPreTransform::search_and_reconstruct(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
float* recons) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del((xt == x) ? nullptr : xt);
float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
index->search_and_reconstruct(n, xt, k, distances, labels, recons_temp);
// Revert transformations from last to first
reverse_chain(n * k, recons_temp, recons);
}
size_t IndexPreTransform::sa_code_size() const {
return index->sa_code_size();
}
void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
const {
if (chain.empty()) {
index->sa_encode(n, x, bytes);
} else {
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
index->sa_encode(n, xt, bytes);
}
}
void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
const {
if (chain.empty()) {
index->sa_decode(n, bytes, x);
} else {
std::unique_ptr<float[]> x1(new float[index->d * n]);
index->sa_decode(n, bytes, x1.get());
// Revert transformations from last to first
reverse_chain(n, x1.get(), x);
}
}
namespace {
struct PreTransformDistanceComputer : DistanceComputer {
const IndexPreTransform* index;
std::unique_ptr<DistanceComputer> sub_dc;
std::unique_ptr<const float[]> query;
explicit PreTransformDistanceComputer(const IndexPreTransform* index)
: index(index), sub_dc(index->index->get_distance_computer()) {}
void set_query(const float* x) override {
const float* xt = index->apply_chain(1, x);
if (xt == x) {
sub_dc->set_query(x);
} else {
query.reset(xt);
sub_dc->set_query(xt);
}
}
float symmetric_dis(idx_t i, idx_t j) override {
return sub_dc->symmetric_dis(i, j);
}
float operator()(idx_t i) override {
return (*sub_dc)(i);
}
};
} // anonymous namespace
DistanceComputer* IndexPreTransform::get_distance_computer() const {
if (chain.empty()) {
return index->get_distance_computer();
} else {
return new PreTransformDistanceComputer(this);
}
}
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#pragma once
#include <faiss/Index.h>
#include <faiss/VectorTransform.h>
namespace faiss {
/** Index that applies a LinearTransform transform on vectors before
* handing them over to a sub-index */
struct IndexPreTransform : Index {
    std::vector<VectorTransform*> chain; ///! chain of transforms
Index* index; ///! the sub-index
bool own_fields; ///! whether pointers are deleted in destructor
explicit IndexPreTransform(Index* index);
IndexPreTransform();
/// ltrans is the last transform before the index
IndexPreTransform(VectorTransform* ltrans, Index* index);
void prepend_transform(VectorTransform* ltrans);
void train(idx_t n, const float* x) override;
void add(idx_t n, const float* x) override;
void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
void reset() override;
/** removes IDs from the index. Not supported by all indexes.
*/
size_t remove_ids(const IDSelector& sel) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
/* range search, no attempt is done to change the radius */
void range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result) const override;
void reconstruct(idx_t key, float* recons) const override;
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
void search_and_reconstruct(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
float* recons) const override;
/// apply the transforms in the chain. The returned float * may be
/// equal to x, otherwise it should be deallocated.
const float* apply_chain(idx_t n, const float* x) const;
/// Reverse the transforms in the chain. May not be implemented for
/// all transforms in the chain or may return approximate results.
void reverse_chain(idx_t n, const float* xt, float* x) const;
DistanceComputer* get_distance_computer() const override;
/* standalone codec interface */
size_t sa_code_size() const override;
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
~IndexPreTransform() override;
};
} // namespace faiss
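// ---------------------------------------------------------------------------
// Illustrative usage sketch for IndexPreTransform (not part of the original
// sources): reduce 128-d vectors to 32-d with a PCAMatrix before handing
// them to a flat L2 sub-index. xb/xq and example_pre_transform are
// caller-provided placeholders.
#include <faiss/IndexFlat.h>
#include <faiss/IndexPreTransform.h>
#include <faiss/VectorTransform.h>
#include <vector>

inline void example_pre_transform(
        size_t nb, const float* xb, size_t nq, const float* xq) {
    int d_in = 128, d_out = 32;
    auto* pca = new faiss::PCAMatrix(d_in, d_out); // trained in train() below
    auto* sub = new faiss::IndexFlatL2(d_out);
    faiss::IndexPreTransform index(pca, sub);
    index.own_fields = true;          // wrapper deletes pca and sub
    index.train(nb, xb);              // trains the PCA, then the sub-index
    index.add(nb, xb);                // vectors are transformed before add
    faiss::Index::idx_t k = 5;
    std::vector<float> D(nq * k);
    std::vector<faiss::Index::idx_t> I(nq * k);
    index.search(nq, xq, k, D.data(), I.data());
}
// ---------------------------------------------------------------------------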
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/IndexRefine.h>
#include <faiss/IndexFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/utils.h>
namespace faiss {
/***************************************************
* IndexRefine
***************************************************/
IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
: Index(base_index->d, base_index->metric_type),
base_index(base_index),
refine_index(refine_index) {
own_fields = own_refine_index = false;
if (refine_index != nullptr) {
FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
FAISS_THROW_IF_NOT(
base_index->metric_type == refine_index->metric_type);
is_trained = base_index->is_trained && refine_index->is_trained;
FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal);
} // other case is useful only to construct an IndexRefineFlat
ntotal = base_index->ntotal;
}
IndexRefine::IndexRefine()
: base_index(nullptr),
refine_index(nullptr),
own_fields(false),
own_refine_index(false) {}
void IndexRefine::train(idx_t n, const float* x) {
base_index->train(n, x);
refine_index->train(n, x);
is_trained = true;
}
void IndexRefine::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT(is_trained);
base_index->add(n, x);
refine_index->add(n, x);
ntotal = refine_index->ntotal;
}
void IndexRefine::reset() {
base_index->reset();
refine_index->reset();
ntotal = 0;
}
namespace {
typedef faiss::Index::idx_t idx_t;
template <class C>
static void reorder_2_heaps(
idx_t n,
idx_t k,
idx_t* labels,
float* distances,
idx_t k_base,
const idx_t* base_labels,
const float* base_distances) {
#pragma omp parallel for
for (idx_t i = 0; i < n; i++) {
idx_t* idxo = labels + i * k;
float* diso = distances + i * k;
const idx_t* idxi = base_labels + i * k_base;
const float* disi = base_distances + i * k_base;
heap_heapify<C>(k, diso, idxo, disi, idxi, k);
if (k_base != k) { // add remaining elements
heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
}
heap_reorder<C>(k, diso, idxo);
}
}
} // anonymous namespace
void IndexRefine::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);
idx_t k_base = idx_t(k * k_factor);
idx_t* base_labels = labels;
float* base_distances = distances;
ScopeDeleter<idx_t> del1;
ScopeDeleter<float> del2;
if (k != k_base) {
base_labels = new idx_t[n * k_base];
del1.set(base_labels);
base_distances = new float[n * k_base];
del2.set(base_distances);
}
base_index->search(n, x, k_base, base_distances, base_labels);
for (int i = 0; i < n * k_base; i++)
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
// parallelize over queries
#pragma omp parallel if (n > 1)
{
std::unique_ptr<DistanceComputer> dc(
refine_index->get_distance_computer());
#pragma omp for
for (idx_t i = 0; i < n; i++) {
dc->set_query(x + i * d);
idx_t ij = i * k_base;
for (idx_t j = 0; j < k_base; j++) {
idx_t idx = base_labels[ij];
if (idx < 0)
break;
base_distances[ij] = (*dc)(idx);
ij++;
}
}
}
// sort and store result
if (metric_type == METRIC_L2) {
typedef CMax<float, idx_t> C;
reorder_2_heaps<C>(
n, k, labels, distances, k_base, base_labels, base_distances);
} else if (metric_type == METRIC_INNER_PRODUCT) {
typedef CMin<float, idx_t> C;
reorder_2_heaps<C>(
n, k, labels, distances, k_base, base_labels, base_distances);
} else {
FAISS_THROW_MSG("Metric type not supported");
}
}
void IndexRefine::reconstruct(idx_t key, float* recons) const {
refine_index->reconstruct(key, recons);
}
size_t IndexRefine::sa_code_size() const {
return base_index->sa_code_size() + refine_index->sa_code_size();
}
void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
std::unique_ptr<uint8_t[]> tmp1(new uint8_t[n * cs1]);
base_index->sa_encode(n, x, tmp1.get());
std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]);
refine_index->sa_encode(n, x, tmp2.get());
for (size_t i = 0; i < n; i++) {
uint8_t* b = bytes + i * (cs1 + cs2);
memcpy(b, tmp1.get() + cs1 * i, cs1);
memcpy(b + cs1, tmp2.get() + cs2 * i, cs2);
}
}
void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
std::unique_ptr<uint8_t[]> tmp2(
new uint8_t[n * refine_index->sa_code_size()]);
for (size_t i = 0; i < n; i++) {
memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2), cs2);
}
refine_index->sa_decode(n, tmp2.get(), x);
}
IndexRefine::~IndexRefine() {
if (own_fields)
delete base_index;
if (own_refine_index)
delete refine_index;
}
/***************************************************
* IndexRefineFlat
***************************************************/
IndexRefineFlat::IndexRefineFlat(Index* base_index)
: IndexRefine(
base_index,
new IndexFlat(base_index->d, base_index->metric_type)) {
is_trained = base_index->is_trained;
own_refine_index = true;
FAISS_THROW_IF_NOT_MSG(
base_index->ntotal == 0,
"base_index should be empty in the beginning");
}
IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
: IndexRefine(base_index, nullptr) {
is_trained = base_index->is_trained;
refine_index = new IndexFlat(base_index->d, base_index->metric_type);
own_refine_index = true;
refine_index->add(base_index->ntotal, xb);
}
IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
own_refine_index = true;
}
void IndexRefineFlat::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);
idx_t k_base = idx_t(k * k_factor);
idx_t* base_labels = labels;
float* base_distances = distances;
ScopeDeleter<idx_t> del1;
ScopeDeleter<float> del2;
if (k != k_base) {
base_labels = new idx_t[n * k_base];
del1.set(base_labels);
base_distances = new float[n * k_base];
del2.set(base_distances);
}
base_index->search(n, x, k_base, base_distances, base_labels);
for (int i = 0; i < n * k_base; i++)
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
// compute refined distances
auto rf = dynamic_cast<const IndexFlat*>(refine_index);
FAISS_THROW_IF_NOT(rf);
rf->compute_distance_subset(n, x, k_base, base_distances, base_labels);
// sort and store result
if (metric_type == METRIC_L2) {
typedef CMax<float, idx_t> C;
reorder_2_heaps<C>(
n, k, labels, distances, k_base, base_labels, base_distances);
} else if (metric_type == METRIC_INNER_PRODUCT) {
typedef CMin<float, idx_t> C;
reorder_2_heaps<C>(
n, k, labels, distances, k_base, base_labels, base_distances);
} else {
FAISS_THROW_MSG("Metric type not supported");
}
}
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/Index.h>
namespace faiss {
/** Index that queries in a base_index (a fast one) and refines the
* results with an exact search, hopefully improving the results.
*/
struct IndexRefine : Index {
/// faster index to pre-select the vectors that should be filtered
Index* base_index;
/// refinement index
Index* refine_index;
bool own_fields; ///< should the base index be deallocated?
bool own_refine_index; ///< same with the refinement index
/// factor between k requested in search and the k requested from
/// the base_index (should be >= 1)
float k_factor = 1;
/// initialize from empty index
IndexRefine(Index* base_index, Index* refine_index);
IndexRefine();
void train(idx_t n, const float* x) override;
void add(idx_t n, const float* x) override;
void reset() override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
// reconstruct is routed to the refine_index
void reconstruct(idx_t key, float* recons) const override;
/* standalone codec interface: the base_index codes are interleaved with the
* refine_index ones */
size_t sa_code_size() const override;
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
/// The sa_decode decodes from the index_refine, which is assumed to be more
/// accurate
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
~IndexRefine() override;
};
/** Version where the refinement index is an IndexFlat. It has one additional
* constructor that takes a table of elements to add to the flat refinement
* index */
struct IndexRefineFlat : IndexRefine {
explicit IndexRefineFlat(Index* base_index);
IndexRefineFlat(Index* base_index, const float* xb);
IndexRefineFlat();
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
};
} // namespace faiss
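// ---------------------------------------------------------------------------
// Illustrative usage sketch for IndexRefineFlat (not part of the original
// sources): a coarse PQ index pre-selects candidates and the flat refinement
// index re-ranks them with exact distances. k_factor sets how many candidates
// per query are fetched from the base index. xb/xq and example_refine_flat
// are placeholders.
#include <faiss/IndexPQ.h>
#include <faiss/IndexRefine.h>
#include <vector>

inline void example_refine_flat(
        int d, size_t nb, const float* xb, size_t nq, const float* xq) {
    auto* base = new faiss::IndexPQ(d, 8, 8, faiss::METRIC_L2);
    faiss::IndexRefineFlat index(base);
    index.own_fields = true;          // delete the base index with the wrapper
    index.k_factor = 4;               // fetch 4*k candidates from the base
    index.train(nb, xb);
    index.add(nb, xb);                // stored as PQ codes and as raw floats
    faiss::Index::idx_t k = 10;
    std::vector<float> D(nq * k);
    std::vector<faiss::Index::idx_t> I(nq * k);
    index.search(nq, xq, k, D.data(), I.data());
}
// ---------------------------------------------------------------------------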
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cinttypes>
#include <faiss/IndexReplicas.h>
#include <faiss/impl/FaissAssert.h>
namespace faiss {
template <typename IndexT>
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(bool threaded)
: ThreadedIndex<IndexT>(threaded) {}
template <typename IndexT>
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(idx_t d, bool threaded)
: ThreadedIndex<IndexT>(d, threaded) {}
template <typename IndexT>
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(int d, bool threaded)
: ThreadedIndex<IndexT>(d, threaded) {}
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::onAfterAddIndex(IndexT* index) {
// Make sure that the parameters are the same for all prior indices, unless
// we're the first index to be added
if (this->count() > 0 && this->at(0) != index) {
auto existing = this->at(0);
FAISS_THROW_IF_NOT_FMT(
index->ntotal == existing->ntotal,
"IndexReplicas: newly added index does "
"not have same number of vectors as prior index; "
"prior index has %" PRId64 " vectors, new index has %" PRId64,
existing->ntotal,
index->ntotal);
FAISS_THROW_IF_NOT_MSG(
index->is_trained == existing->is_trained,
"IndexReplicas: newly added index does "
"not have same train status as prior index");
FAISS_THROW_IF_NOT_MSG(
index->d == existing->d,
"IndexReplicas: newly added index does "
"not have same dimension as prior index");
} else {
syncWithSubIndexes();
}
}
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::onAfterRemoveIndex(IndexT* index) {
syncWithSubIndexes();
}
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::train(idx_t n, const component_t* x) {
auto fn = [n, x](int i, IndexT* index) {
if (index->verbose) {
printf("begin train replica %d on %" PRId64 " points\n", i, n);
}
index->train(n, x);
if (index->verbose) {
printf("end train replica %d\n", i);
}
};
this->runOnIndex(fn);
syncWithSubIndexes();
}
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::add(idx_t n, const component_t* x) {
auto fn = [n, x](int i, IndexT* index) {
if (index->verbose) {
printf("begin add replica %d on %" PRId64 " points\n", i, n);
}
index->add(n, x);
if (index->verbose) {
printf("end add replica %d\n", i);
}
};
this->runOnIndex(fn);
syncWithSubIndexes();
}
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::reconstruct(idx_t n, component_t* x) const {
FAISS_THROW_IF_NOT_MSG(this->count() > 0, "no replicas in index");
// Just pass to the first replica
this->at(0)->reconstruct(n, x);
}
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::search(
idx_t n,
const component_t* x,
idx_t k,
distance_t* distances,
idx_t* labels) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT_MSG(this->count() > 0, "no replicas in index");
if (n == 0) {
return;
}
auto dim = this->d;
size_t componentsPerVec = sizeof(component_t) == 1 ? (dim + 7) / 8 : dim;
// Partition the query by the number of indices we have
faiss::Index::idx_t queriesPerIndex =
(faiss::Index::idx_t)(n + this->count() - 1) /
(faiss::Index::idx_t)this->count();
FAISS_ASSERT(n / queriesPerIndex <= this->count());
auto fn = [queriesPerIndex, componentsPerVec, n, x, k, distances, labels](
int i, const IndexT* index) {
faiss::Index::idx_t base = (faiss::Index::idx_t)i * queriesPerIndex;
if (base < n) {
auto numForIndex = std::min(queriesPerIndex, n - base);
if (index->verbose) {
printf("begin search replica %d on %" PRId64 " points\n",
i,
numForIndex);
}
index->search(
numForIndex,
x + base * componentsPerVec,
k,
distances + base * k,
labels + base * k);
if (index->verbose) {
printf("end search replica %d\n", i);
}
}
};
this->runOnIndex(fn);
}
// FIXME: assumes that nothing is currently running on the sub-indexes, which is
// true with the normal API, but should use the runOnIndex API instead
template <typename IndexT>
void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
if (!this->count()) {
this->is_trained = false;
this->ntotal = 0;
return;
}
auto firstIndex = this->at(0);
this->metric_type = firstIndex->metric_type;
this->is_trained = firstIndex->is_trained;
this->ntotal = firstIndex->ntotal;
for (int i = 1; i < this->count(); ++i) {
auto index = this->at(i);
FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
FAISS_THROW_IF_NOT(this->d == index->d);
FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
FAISS_THROW_IF_NOT(this->ntotal == index->ntotal);
}
}
// No metric_type for IndexBinary
template <>
void IndexReplicasTemplate<IndexBinary>::syncWithSubIndexes() {
if (!this->count()) {
this->is_trained = false;
this->ntotal = 0;
return;
}
auto firstIndex = this->at(0);
this->is_trained = firstIndex->is_trained;
this->ntotal = firstIndex->ntotal;
for (int i = 1; i < this->count(); ++i) {
auto index = this->at(i);
FAISS_THROW_IF_NOT(this->d == index->d);
FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
FAISS_THROW_IF_NOT(this->ntotal == index->ntotal);
}
}
// explicit instantiations
template struct IndexReplicasTemplate<Index>;
template struct IndexReplicasTemplate<IndexBinary>;
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/Index.h>
#include <faiss/IndexBinary.h>
#include <faiss/impl/ThreadedIndex.h>
namespace faiss {
/// Takes individual faiss::Index instances, and splits queries for
/// sending to each Index instance, and joins the results together
/// when done.
/// Each index is managed by a separate CPU thread.
template <typename IndexT>
class IndexReplicasTemplate : public ThreadedIndex<IndexT> {
public:
using idx_t = typename IndexT::idx_t;
using component_t = typename IndexT::component_t;
using distance_t = typename IndexT::distance_t;
/// The dimension that all sub-indices must share will be the dimension of
/// the first sub-index added
/// @param threaded do we use one thread per sub-index or do queries
/// sequentially?
explicit IndexReplicasTemplate(bool threaded = true);
/// @param d the dimension that all sub-indices must share
/// @param threaded do we use one thread per sub index or do queries
/// sequentially?
explicit IndexReplicasTemplate(idx_t d, bool threaded = true);
/// int version due to the implicit bool conversion ambiguity of int as
/// dimension
explicit IndexReplicasTemplate(int d, bool threaded = true);
/// Alias for addIndex()
void add_replica(IndexT* index) {
this->addIndex(index);
}
/// Alias for removeIndex()
void remove_replica(IndexT* index) {
this->removeIndex(index);
}
/// faiss::Index API
/// All indices receive the same call
void train(idx_t n, const component_t* x) override;
/// faiss::Index API
/// All indices receive the same call
void add(idx_t n, const component_t* x) override;
/// faiss::Index API
/// Query is partitioned into a slice for each sub-index
/// split by ceil(n / #indices) for our sub-indices
void search(
idx_t n,
const component_t* x,
idx_t k,
distance_t* distances,
idx_t* labels) const override;
/// reconstructs from the first index
void reconstruct(idx_t, component_t* v) const override;
/// Synchronize the top-level index (IndexShards) with data in the
/// sub-indices
void syncWithSubIndexes();
protected:
/// Called just after an index is added
void onAfterAddIndex(IndexT* index) override;
/// Called just after an index is removed
void onAfterRemoveIndex(IndexT* index) override;
};
using IndexReplicas = IndexReplicasTemplate<Index>;
using IndexBinaryReplicas = IndexReplicasTemplate<IndexBinary>;
} // namespace faiss
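// ---------------------------------------------------------------------------
// Illustrative usage sketch for IndexReplicas (not part of the original
// sources): queries are partitioned across replicas that hold identical data,
// each replica served by its own thread. Flat indexes are used here only to
// keep the sketch short; xb/xq and example_replicas are placeholders.
#include <faiss/IndexFlat.h>
#include <faiss/IndexReplicas.h>
#include <vector>

inline void example_replicas(
        int d, size_t nb, const float* xb, size_t nq, const float* xq) {
    faiss::IndexFlatL2 replica0(d), replica1(d);
    replica0.add(nb, xb);
    replica1.add(nb, xb);             // replicas must hold the same vectors
    faiss::IndexReplicas index(d);    // threaded by default
    index.add_replica(&replica0);
    index.add_replica(&replica1);
    faiss::Index::idx_t k = 10;
    std::vector<float> D(nq * k);
    std::vector<faiss::Index::idx_t> I(nq * k);
    index.search(nq, xq, k, D.data(), I.data()); // queries split over replicas
}
// ---------------------------------------------------------------------------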
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexScalarQuantizer.h>
#include <algorithm>
#include <cstdio>
#include <omp.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/ScalarQuantizer.h>
#include <faiss/utils/utils.h>
namespace faiss {
/*******************************************************************
* IndexScalarQuantizer implementation
********************************************************************/
IndexScalarQuantizer::IndexScalarQuantizer(
int d,
ScalarQuantizer::QuantizerType qtype,
MetricType metric)
: IndexFlatCodes(0, d, metric), sq(d, qtype) {
is_trained = qtype == ScalarQuantizer::QT_fp16 ||
qtype == ScalarQuantizer::QT_8bit_direct;
code_size = sq.code_size;
}
IndexScalarQuantizer::IndexScalarQuantizer()
: IndexScalarQuantizer(0, ScalarQuantizer::QT_8bit) {}
void IndexScalarQuantizer::train(idx_t n, const float* x) {
sq.train(n, x);
is_trained = true;
}
void IndexScalarQuantizer::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);
FAISS_THROW_IF_NOT(
metric_type == METRIC_L2 || metric_type == METRIC_INNER_PRODUCT);
#pragma omp parallel
{
InvertedListScanner* scanner =
sq.select_InvertedListScanner(metric_type, nullptr, true);
ScopeDeleter1<InvertedListScanner> del(scanner);
scanner->list_no = 0; // directly the list number
#pragma omp for
for (idx_t i = 0; i < n; i++) {
float* D = distances + k * i;
idx_t* I = labels + k * i;
// re-order heap
if (metric_type == METRIC_L2) {
maxheap_heapify(k, D, I);
} else {
minheap_heapify(k, D, I);
}
scanner->set_query(x + i * d);
scanner->scan_codes(ntotal, codes.data(), nullptr, D, I, k);
// re-order heap
if (metric_type == METRIC_L2) {
maxheap_reorder(k, D, I);
} else {
minheap_reorder(k, D, I);
}
}
}
}
DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
ScalarQuantizer::SQDistanceComputer* dc =
sq.get_distance_computer(metric_type);
dc->code_size = sq.code_size;
dc->codes = codes.data();
return dc;
}
/* Codec interface */
void IndexScalarQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
const {
FAISS_THROW_IF_NOT(is_trained);
sq.compute_codes(x, bytes, n);
}
void IndexScalarQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x)
const {
FAISS_THROW_IF_NOT(is_trained);
sq.decode(bytes, x, n);
}
/*******************************************************************
* IndexIVFScalarQuantizer implementation
********************************************************************/
IndexIVFScalarQuantizer::IndexIVFScalarQuantizer(
Index* quantizer,
size_t d,
size_t nlist,
ScalarQuantizer::QuantizerType qtype,
MetricType metric,
bool encode_residual)
: IndexIVF(quantizer, d, nlist, 0, metric),
sq(d, qtype),
by_residual(encode_residual) {
code_size = sq.code_size;
// was not known at construction time
invlists->code_size = code_size;
is_trained = false;
}
IndexIVFScalarQuantizer::IndexIVFScalarQuantizer()
: IndexIVF(), by_residual(true) {}
void IndexIVFScalarQuantizer::train_residual(idx_t n, const float* x) {
sq.train_residual(n, x, quantizer, by_residual, verbose);
}
void IndexIVFScalarQuantizer::encode_vectors(
idx_t n,
const float* x,
const idx_t* list_nos,
uint8_t* codes,
bool include_listnos) const {
std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
size_t coarse_size = include_listnos ? coarse_code_size() : 0;
memset(codes, 0, (code_size + coarse_size) * n);
#pragma omp parallel if (n > 1000)
{
std::vector<float> residual(d);
#pragma omp for
for (idx_t i = 0; i < n; i++) {
int64_t list_no = list_nos[i];
if (list_no >= 0) {
const float* xi = x + i * d;
uint8_t* code = codes + i * (code_size + coarse_size);
if (by_residual) {
quantizer->compute_residual(xi, residual.data(), list_no);
xi = residual.data();
}
if (coarse_size) {
encode_listno(list_no, code);
}
squant->encode_vector(xi, code + coarse_size);
}
}
}
}
void IndexIVFScalarQuantizer::sa_decode(idx_t n, const uint8_t* codes, float* x)
const {
std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
size_t coarse_size = coarse_code_size();
#pragma omp parallel if (n > 1000)
{
std::vector<float> residual(d);
#pragma omp for
for (idx_t i = 0; i < n; i++) {
const uint8_t* code = codes + i * (code_size + coarse_size);
int64_t list_no = decode_listno(code);
float* xi = x + i * d;
squant->decode_vector(code + coarse_size, xi);
if (by_residual) {
quantizer->reconstruct(list_no, residual.data());
for (size_t j = 0; j < d; j++) {
xi[j] += residual[j];
}
}
}
}
}
void IndexIVFScalarQuantizer::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* coarse_idx) {
FAISS_THROW_IF_NOT(is_trained);
size_t nadd = 0;
std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
DirectMapAdd dm_add(direct_map, n, xids);
#pragma omp parallel reduction(+ : nadd)
{
std::vector<float> residual(d);
std::vector<uint8_t> one_code(code_size);
int nt = omp_get_num_threads();
int rank = omp_get_thread_num();
// each thread takes care of a subset of lists
for (size_t i = 0; i < n; i++) {
int64_t list_no = coarse_idx[i];
if (list_no >= 0 && list_no % nt == rank) {
int64_t id = xids ? xids[i] : ntotal + i;
const float* xi = x + i * d;
if (by_residual) {
quantizer->compute_residual(xi, residual.data(), list_no);
xi = residual.data();
}
memset(one_code.data(), 0, code_size);
squant->encode_vector(xi, one_code.data());
size_t ofs = invlists->add_entry(list_no, id, one_code.data());
dm_add.add(i, list_no, ofs);
nadd++;
} else if (rank == 0 && list_no == -1) {
dm_add.add(i, -1, 0);
}
}
}
ntotal += n;
}
InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner(
bool store_pairs) const {
return sq.select_InvertedListScanner(
metric_type, quantizer, store_pairs, by_residual);
}
void IndexIVFScalarQuantizer::reconstruct_from_offset(
int64_t list_no,
int64_t offset,
float* recons) const {
std::vector<float> centroid(d);
quantizer->reconstruct(list_no, centroid.data());
const uint8_t* code = invlists->get_single_code(list_no, offset);
sq.decode(code, recons, 1);
for (int i = 0; i < d; ++i) {
recons[i] += centroid[i];
}
}
} // namespace faiss
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#ifndef FAISS_INDEX_SCALAR_QUANTIZER_H
#define FAISS_INDEX_SCALAR_QUANTIZER_H
#include <stdint.h>
#include <vector>
#include <faiss/IndexFlatCodes.h>
#include <faiss/IndexIVF.h>
#include <faiss/impl/ScalarQuantizer.h>
namespace faiss {
/**
* The uniform quantizer has a range [vmin, vmax]. The range can be
* the same for all dimensions (uniform) or specific per dimension
* (default).
*/
struct IndexScalarQuantizer : IndexFlatCodes {
/// Used to encode the vectors
ScalarQuantizer sq;
    /** Constructor.
     *
     * @param d      dimensionality of the input vectors
     * @param qtype  scalar quantizer type used to encode each dimension
     * @param metric distance metric used for search
     */
IndexScalarQuantizer(
int d,
ScalarQuantizer::QuantizerType qtype,
MetricType metric = METRIC_L2);
IndexScalarQuantizer();
void train(idx_t n, const float* x) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
DistanceComputer* get_distance_computer() const override;
/* standalone codec interface */
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
};
/** An IVF implementation where the components of the residuals are
* encoded with a scalar quantizer. All distance computations
* are asymmetric, so the encoded vectors are decoded and approximate
* distances are computed.
*/
struct IndexIVFScalarQuantizer : IndexIVF {
ScalarQuantizer sq;
bool by_residual;
IndexIVFScalarQuantizer(
Index* quantizer,
size_t d,
size_t nlist,
ScalarQuantizer::QuantizerType qtype,
MetricType metric = METRIC_L2,
bool encode_residual = true);
IndexIVFScalarQuantizer();
void train_residual(idx_t n, const float* x) override;
void encode_vectors(
idx_t n,
const float* x,
const idx_t* list_nos,
uint8_t* codes,
bool include_listnos = false) const override;
void add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) override;
InvertedListScanner* get_InvertedListScanner(
bool store_pairs) const override;
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
const override;
/* standalone codec interface */
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
};
} // namespace faiss
#endif
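// ---------------------------------------------------------------------------
// Illustrative usage sketches for the scalar quantizer indexes (not part of
// the original sources). Each dimension is encoded on 8 bits, so a vector
// costs d bytes instead of 4*d. xb/xq and the example_* helper names are
// placeholders supplied by the caller.
#include <faiss/IndexFlat.h>
#include <faiss/IndexScalarQuantizer.h>
#include <vector>

inline void example_scalar_quantizer(
        int d, size_t nb, const float* xb, size_t nq, const float* xq) {
    faiss::IndexScalarQuantizer index(
            d, faiss::ScalarQuantizer::QT_8bit, faiss::METRIC_L2);
    index.train(nb, xb);              // estimates per-dimension [vmin, vmax]
    index.add(nb, xb);
    faiss::Index::idx_t k = 10;
    std::vector<float> D(nq * k);
    std::vector<faiss::Index::idx_t> I(nq * k);
    index.search(nq, xq, k, D.data(), I.data());
}

// IVF variant: vectors are assigned to nlist inverted lists by a coarse
// quantizer, and the residuals are scalar-quantized inside each list.
inline void example_ivf_scalar_quantizer(
        int d, size_t nb, const float* xb, size_t nq, const float* xq) {
    size_t nlist = 100;               // number of inverted lists
    auto* quantizer = new faiss::IndexFlatL2(d);
    faiss::IndexIVFScalarQuantizer index(
            quantizer, d, nlist, faiss::ScalarQuantizer::QT_8bit);
    index.own_fields = true;          // delete the coarse quantizer with index
    index.train(nb, xb);              // coarse k-means + SQ training
    index.add(nb, xb);
    index.nprobe = 8;                 // inverted lists visited per query
    faiss::Index::idx_t k = 10;
    std::vector<float> D(nq * k);
    std::vector<faiss::Index::idx_t> I(nq * k);
    index.search(nq, xq, k, D.data(), I.data());
}
// ---------------------------------------------------------------------------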