semantic_search_recommended.py 4.77 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
This script showcases a recommended approach to perform semantic search using quantized embeddings with FAISS and usearch.
In particular, it uses binary search with int8 rescoring. The binary search is highly efficient, and its index can be kept
in memory even for massive datasets: it takes (num_dimensions * num_documents / 8) bytes, i.e. 1.19GB for 10 million embeddings.
"""

import json
import os
import time

import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
from datasets import load_dataset

import faiss
from usearch.index import Index
# We use usearch as it can efficiently load int8 vectors from disk.

# Load the model
# NOTE: Because we are only comparing questions here, we will use the "query" prompt for everything.
# Normally you don't use this prompt for documents, but only for the queries
model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={"query": "Represent this sentence for searching relevant passages: "},
    default_prompt_name="query",
)

# Load a corpus with texts
dataset = load_dataset("quora", split="train").map(
    lambda batch: {"text": [text for sample in batch["questions"] for text in sample["text"]]},
    batched=True,
    remove_columns=["questions", "is_duplicate"],
)
max_corpus_size = 100_000
corpus = dataset["text"][:max_corpus_size]

# Apply some default query
query = "How do I become a good programmer?"

# Try to load the precomputed binary and int8 indices
if os.path.exists("quora_faiss_ubinary.index"):
    binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("quora_faiss_ubinary.index")
    int8_view = Index.restore("quora_usearch_int8.index", view=True)

else:
    # Encode the corpus using the full precision
    full_corpus_embeddings = model.encode(corpus, normalize_embeddings=True, show_progress_bar=True)

    # Convert the embeddings to "ubinary" for efficient FAISS search
    ubinary_embeddings = quantize_embeddings(full_corpus_embeddings, "ubinary")
    binary_index = faiss.IndexBinaryFlat(1024)
    binary_index.add(ubinary_embeddings)
    faiss.write_index_binary(binary_index, "quora_faiss_ubinary.index")

    # Convert the embeddings to "int8" for efficiently loading int8 indices with usearch
    int8_embeddings = quantize_embeddings(full_corpus_embeddings, "int8")
    index = Index(ndim=1024, metric="ip", dtype="i8")
    index.add(np.arange(len(int8_embeddings)), int8_embeddings)
    index.save("quora_usearch_int8.index")
    del index

    # Load the int8 index as a view, which does not cost any memory
    int8_view = Index.restore("quora_usearch_int8.index", view=True)


def search(query, top_k: int = 10, rescore_multiplier: int = 4):
    # 1. Embed the query as float32
    start_time = time.time()
    query_embedding = model.encode(query)
    embed_time = time.time() - start_time

    # 2. Quantize the query to ubinary
    start_time = time.time()
    query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
    quantize_time = time.time() - start_time

    # 3. Search the binary index
    start_time = time.time()
    _scores, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rescore_multiplier)
    binary_ids = binary_ids[0]
    search_time = time.time() - start_time

    # 4. Load the corresponding int8 embeddings
    start_time = time.time()
    int8_embeddings = int8_view[binary_ids].astype(int)
    load_time = time.time() - start_time

    # 5. Rescore the top_k * rescore_multiplier using the float32 query embedding and the int8 document embeddings
    start_time = time.time()
    scores = query_embedding @ int8_embeddings.T
    rescore_time = time.time() - start_time

    # 6. Sort the scores and return the top_k
    start_time = time.time()
    indices = (-scores).argsort()[:top_k]
    top_k_indices = binary_ids[indices]
    top_k_scores = scores[indices]
    sort_time = time.time() - start_time

    return (
        top_k_scores.tolist(),
        top_k_indices.tolist(),
        {
            "Embed Time": f"{embed_time:.4f} s",
            "Quantize Time": f"{quantize_time:.4f} s",
            "Search Time": f"{search_time:.4f} s",
            "Load Time": f"{load_time:.4f} s",
            "Rescore Time": f"{rescore_time:.4f} s",
            "Sort Time": f"{sort_time:.4f} s",
            "Total Retrieval Time": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s",
        },
    )


while True:
    scores, indices, timings = search(query)

    # Output the results
    print(f"Timings:\n{json.dumps(timings, indent=2)}")
    print(f"Query: {query}")
    for score, index in zip(scores, indices):
        print(f"(Score: {score:.4f}) {corpus[index]}")
    print("")

    # 10. Prompt for more queries
    query = input("Please enter a question: ")