test_pq_encoding.cpp 2.46 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstdint>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

#include <gtest/gtest.h>

#include <faiss/impl/ProductQuantizer.h>

namespace {

// Returns `s` pseudo-random 64-bit values to feed the encoders under test.
//
// A locally seeded std::mt19937_64 is used rather than rand():
//  - rand() only produces values in [0, RAND_MAX], which may be as narrow as
//    15 bits, leaving the high bits of wider codes untested;
//  - rand() shares global state, so each test's inputs would depend on which
//    tests ran before it.
// The fixed seed keeps every call (and therefore every test) reproducible.
// The return type is non-const so callers can move from the result.
std::vector<uint64_t> random_vector(size_t s) {
    std::mt19937_64 rng(1234);
    std::vector<uint64_t> v(s);
    for (auto& x : v) {
        x = rng();
    }
    return v;
}

} // namespace

// Round-trips `nsubcodes` random values through PQEncoderGeneric /
// PQDecoderGeneric for every code width in [minbits, maxbits] and checks that
// each decoded value equals the corresponding input truncated to nbits.
TEST(PQEncoderGeneric, encode) {
    const int nsubcodes = 97;
    const int minbits = 1;
    const int maxbits = 24;
    const std::vector<uint64_t> values = random_vector(nsubcodes);

    for (int nbits = minbits; nbits <= maxbits; ++nbits) {
        // Attach the current width to any failure message instead of
        // printing (and flushing stderr) on every iteration.
        SCOPED_TRACE(testing::Message() << "nbits = " << nbits);

        const uint64_t mask = (1ull << nbits) - 1;
        // Size the buffer exactly for this width (not for maxbits) so that an
        // access past the packed codes lands outside the allocation, where
        // memory checkers can see it; zero it so its contents are
        // deterministic.
        std::vector<uint8_t> codes((nsubcodes * nbits + 7) / 8, 0);

        // NOTE(hoss): Necessary scope to ensure trailing bits are flushed to
        // mem.
        {
            faiss::PQEncoderGeneric encoder(codes.data(), nbits);
            for (const auto& v : values) {
                encoder.encode(v & mask);
            }
        }

        faiss::PQDecoderGeneric decoder(codes.data(), nbits);
        for (int i = 0; i < nsubcodes; ++i) {
            uint64_t v = decoder.decode();
            EXPECT_EQ(values[i] & mask, v);
        }
    }
}

// Round-trips 100 random values through the fixed-width 8-bit encoder/decoder
// pair; every decoded code must equal the low byte of its input.
TEST(PQEncoder8, encode) {
    const int count = 100;
    const std::vector<uint64_t> inputs = random_vector(count);
    const uint64_t low_byte = 0xFF;

    // One byte of storage per sub-code.
    std::vector<uint8_t> buffer(count);

    faiss::PQEncoder8 writer(buffer.data(), 8);
    for (size_t k = 0; k < inputs.size(); ++k) {
        writer.encode(inputs[k] & low_byte);
    }

    faiss::PQDecoder8 reader(buffer.data(), 8);
    for (const uint64_t expected : inputs) {
        const uint64_t got = reader.decode();
        EXPECT_EQ(expected & low_byte, got);
    }
}

// Round-trips 100 random values through the fixed-width 16-bit encoder/decoder
// pair; every decoded code must equal the low 16 bits of its input.
TEST(PQEncoder16, encode) {
    const int count = 100;
    const std::vector<uint64_t> inputs = random_vector(count);
    const uint64_t low_word = 0xFFFF;

    // Two bytes of storage per sub-code.
    std::vector<uint8_t> buffer(2 * count);

    faiss::PQEncoder16 writer(buffer.data(), 16);
    for (size_t k = 0; k < inputs.size(); ++k) {
        writer.encode(inputs[k] & low_word);
    }

    faiss::PQDecoder16 reader(buffer.data(), 16);
    for (const uint64_t expected : inputs) {
        const uint64_t got = reader.decode();
        EXPECT_EQ(expected & low_word, got);
    }
}