utils.py 14.3 KB
Newer Older
1
import os
2
import logging
3
4
import requests

5
from typing import List, Union
6

7
8
9
10
from apps.ollama.main import (
    generate_ollama_embeddings,
    GenerateEmbeddingsForm,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
11

12
13
from huggingface_hub import snapshot_download

14
15
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
Steven Kreitzer's avatar
Steven Kreitzer committed
16
from langchain.retrievers import (
17
    ContextualCompressionRetriever,
Steven Kreitzer's avatar
Steven Kreitzer committed
18
19
20
    EnsembleRetriever,
)

21
from typing import Optional
22

Timothy J. Baek's avatar
Timothy J. Baek committed
23
from utils.misc import get_last_user_message, add_or_update_system_message
24
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
25

26
27
# Module-level logger; verbosity for this RAG module is driven by the
# "RAG" entry of SRC_LOG_LEVELS from config.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29


Timothy J. Baek's avatar
Timothy J. Baek committed
30
def query_doc(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
):
    """Embed `query` and run a top-k similarity search on one Chroma collection.

    Args:
        collection_name: Name of the Chroma collection to search.
        query: The user's query text.
        embedding_function: Callable mapping a string to an embedding vector.
        k: Number of nearest results to request.

    Returns:
        The raw Chroma query result dict (ids/documents/metadatas/distances).

    Raises:
        Whatever the Chroma client raises when the collection is missing or
        the query fails — errors propagate to the caller unchanged.
    """
    # The previous `try: ... except Exception as e: raise e` wrapper added
    # nothing (it re-raised every exception unmodified); let errors propagate.
    collection = CHROMA_CLIENT.get_collection(name=collection_name)
    query_embeddings = embedding_function(query)

    result = collection.query(
        query_embeddings=[query_embeddings],
        n_results=k,
    )

    log.info(f"query_doc:result {result}")
    return result
49
50


Timothy J. Baek's avatar
Timothy J. Baek committed
51
52
53
54
55
56
def query_doc_with_hybrid_search(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    """Hybrid search over one collection: BM25 + vector retrieval, then rerank.

    Builds a 50/50 ensemble of a BM25 retriever (over every document in the
    collection) and a Chroma vector retriever, reranks the candidates with
    `reranking_function` via RerankCompressor, and drops results whose score
    falls below `r`.

    Args:
        collection_name: Name of the Chroma collection to search.
        query: The user's query text.
        embedding_function: Callable mapping text to an embedding vector.
        k: Number of results to keep after reranking.
        reranking_function: Cross-encoder used for reranking (may be None,
            in which case RerankCompressor falls back to cosine similarity).
        r: Relevance-score cutoff applied after reranking.

    Returns:
        A Chroma-style result dict with "distances" (rerank scores),
        "documents" and "metadatas", each wrapped in a single-query list.

    Raises:
        Whatever the underlying retrievers raise — errors propagate unchanged.
    """
    # The previous `except Exception as e: raise e` wrapper was a no-op;
    # let errors propagate directly.
    collection = CHROMA_CLIENT.get_collection(name=collection_name)
    documents = collection.get()  # get all documents

    bm25_retriever = BM25Retriever.from_texts(
        texts=documents.get("documents"),
        metadatas=documents.get("metadatas"),
    )
    bm25_retriever.k = k

    chroma_retriever = ChromaRetriever(
        collection=collection,
        embedding_function=embedding_function,
        top_n=k,
    )

    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
    )

    compressor = RerankCompressor(
        embedding_function=embedding_function,
        top_n=k,
        reranking_function=reranking_function,
        r_score=r,
    )

    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=ensemble_retriever
    )

    result = compression_retriever.invoke(query)
    # Re-shape the Document list into a Chroma-style result dict; the rerank
    # score stored in each document's metadata doubles as the "distance".
    result = {
        "distances": [[d.metadata.get("score") for d in result]],
        "documents": [[d.page_content for d in result]],
        "metadatas": [[d.metadata for d in result]],
    }

    log.info(f"query_doc_with_hybrid_search:result {result}")
    return result


Steven Kreitzer's avatar
Steven Kreitzer committed
103
def merge_and_sort_query_results(query_results, k, reverse=False):
    """Merge several Chroma-style query results into one, ordered by distance.

    Each entry of `query_results` is a dict with "distances", "documents"
    and "metadatas", each wrapped in a single-query list. The flattened
    triples are sorted by distance (ascending by default; descending when
    `reverse` is True, as used for rerank scores) and truncated to `k`.

    Returns:
        One dict of the same shape, containing at most `k` entries.
    """
    distances = []
    documents = []
    metadatas = []
    for batch in query_results:
        distances += batch["distances"][0]
        documents += batch["documents"][0]
        metadatas += batch["metadatas"][0]

    # Sort the flattened (distance, document, metadata) triples and keep
    # only the k best; `sorted` is stable, matching the original in-place sort.
    ordered = sorted(
        zip(distances, documents, metadatas),
        key=lambda triple: triple[0],
        reverse=reverse,
    )[:k]

    if ordered:
        # Unzip back into parallel column lists.
        top_distances, top_documents, top_metadatas = (
            list(column) for column in zip(*ordered)
        )
    else:
        # Nothing matched anywhere.
        top_distances, top_documents, top_metadatas = [], [], []

    return {
        "distances": [top_distances],
        "documents": [top_documents],
        "metadatas": [top_metadatas],
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
142
143


Timothy J. Baek's avatar
Timothy J. Baek committed
144
def query_collection(
    collection_names: List[str],
    query: str,
    embedding_function,
    k: int,
):
    """Run `query_doc` over several collections and merge the results.

    Best-effort: a collection that fails (e.g. it does not exist) is logged
    and skipped instead of aborting the whole query.

    Args:
        collection_names: Chroma collections to search.
        query: The user's query text.
        embedding_function: Callable mapping a string to an embedding vector.
        k: Number of results to keep overall.

    Returns:
        A merged Chroma-style result dict, sorted by ascending distance.
    """
    results = []
    for collection_name in collection_names:
        try:
            result = query_doc(
                collection_name=collection_name,
                query=query,
                k=k,
                embedding_function=embedding_function,
            )
            results.append(result)
        except Exception as e:
            # Was a bare `except: pass`, which also swallowed
            # KeyboardInterrupt/SystemExit and hid real failures silently.
            # Keep the best-effort behavior but log what went wrong.
            log.exception(e)
    return merge_and_sort_query_results(results, k=k)


def query_collection_with_hybrid_search(
    collection_names: List[str],
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    """Run hybrid search over several collections and merge the results.

    Best-effort: a collection that fails is logged and skipped instead of
    aborting the whole query. Results are merged with `reverse=True` because
    rerank scores (unlike distances) rank higher-is-better.

    Args:
        collection_names: Chroma collections to search.
        query: The user's query text.
        embedding_function: Callable mapping text to an embedding vector.
        k: Number of results to keep overall.
        reranking_function: Cross-encoder passed through to the reranker.
        r: Relevance-score cutoff applied after reranking.

    Returns:
        A merged Chroma-style result dict, sorted by descending score.
    """
    results = []
    for collection_name in collection_names:
        try:
            result = query_doc_with_hybrid_search(
                collection_name=collection_name,
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                r=r,
            )
            results.append(result)
        except Exception as e:
            # Was a bare `except: pass` — see query_collection; keep the
            # best-effort behavior but log the failure.
            log.exception(e)
    return merge_and_sort_query_results(results, k=k, reverse=True)
188
189


Timothy J. Baek's avatar
Timothy J. Baek committed
190
def rag_template(template: str, context: str, query: str):
    """Fill the "[context]" and "[query]" placeholders of a RAG prompt template.

    The two substitutions are performed in a single pass: the template is
    split on "[context]", "[query]" is replaced only within the template
    fragments, and the context is joined back in. The previous sequential
    `replace` calls let a literal "[query]" inside the retrieved context be
    overwritten by the user's query.

    Args:
        template: Prompt template containing "[context]" / "[query]" tokens.
        context: Retrieved context text to substitute for "[context]".
        query: User query text to substitute for "[query]".

    Returns:
        The filled-in prompt string.
    """
    # Only placeholder tokens that appear in the template itself are
    # substituted; tokens inside `context` or `query` are left untouched.
    fragments = template.split("[context]")
    fragments = [fragment.replace("[query]", query) for fragment in fragments]
    return context.join(fragments)
Timothy J. Baek's avatar
Timothy J. Baek committed
194
195


Timothy J. Baek's avatar
Timothy J. Baek committed
196
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
    batch_size,
):
    """Build a callable that maps text (or a list of texts) to embeddings.

    Engine dispatch:
      - "" (empty): use the local sentence-transformers `embedding_function`
        directly (`.encode(...).tolist()`).
      - "ollama": one request per text via generate_ollama_embeddings.
      - "openai": requests of up to `batch_size` texts per call.
    Any other engine value yields None, matching the original fall-through.
    """
    if embedding_engine == "":
        def encode(query):
            return embedding_function.encode(query).tolist()

        return encode

    elif embedding_engine in ["ollama", "openai"]:
        if embedding_engine == "ollama":
            def single(query):
                return generate_ollama_embeddings(
                    GenerateEmbeddingsForm(
                        **{
                            "model": embedding_model,
                            "prompt": query,
                        }
                    )
                )

        elif embedding_engine == "openai":
            def single(query):
                return generate_openai_embeddings(
                    model=embedding_model,
                    text=query,
                    key=openai_key,
                    url=openai_url,
                )

        def embed(query):
            # Scalar input: one call, one vector.
            if not isinstance(query, list):
                return single(query)
            # OpenAI accepts batches, so chunk by batch_size; the ollama
            # endpoint takes one prompt at a time.
            if embedding_engine == "openai":
                vectors = []
                for start in range(0, len(query), batch_size):
                    vectors.extend(single(query[start : start + batch_size]))
                return vectors
            return [single(item) for item in query]

        return embed
Steven Kreitzer's avatar
Steven Kreitzer committed
237
238


239
240
241
242
def rag_messages(
    docs,
    messages,
    template,
    embedding_function,
    k,
    reranking_function,
    r,
    hybrid_search,
):
    """Inject retrieved context for the last user message into `messages`.

    For each doc reference: resolves its collection name(s), queries them
    (hybrid or plain vector search), concatenates the retrieved document
    texts into one context string, renders it through `template`, and sets
    the result as the system message. Also builds a citation entry per
    retrieved context.

    Returns:
        Tuple (messages, citations): the updated message list and a list of
        {"source", "document", "metadata"} citation dicts.
    """
    log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
    query = get_last_user_message(messages)

    # Collections already queried in this call — each is searched only once
    # even if several docs reference it.
    extracted_collections = []
    relevant_contexts = []

    for doc in docs:
        context = None

        # A "collection" doc carries several collection names; every other
        # doc type carries exactly one.
        collection_names = (
            doc["collection_names"]
            if doc["type"] == "collection"
            else [doc["collection_name"]]
        )

        collection_names = set(collection_names).difference(extracted_collections)
        if not collection_names:
            log.debug(f"skipping {doc} as it has already been extracted")
            continue

        try:
            if doc["type"] == "text":
                # NOTE(review): `context` is later spread with {**context, ...},
                # which requires a mapping — this assumes doc["content"] is a
                # dict for "text" docs; a plain string would raise. Confirm
                # against the callers that build these doc entries.
                context = doc["content"]
            else:
                if hybrid_search:
                    context = query_collection_with_hybrid_search(
                        collection_names=collection_names,
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
                        reranking_function=reranking_function,
                        r=r,
                    )
                else:
                    context = query_collection(
                        collection_names=collection_names,
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
                    )
        except Exception as e:
            # Retrieval is best-effort: log and continue without this doc.
            log.exception(e)
            context = None

        if context:
            # Keep the originating doc alongside the result for citations.
            relevant_contexts.append({**context, "source": doc})

        extracted_collections.extend(collection_names)

    context_string = ""

    citations = []
    for context in relevant_contexts:
        try:
            if "documents" in context:
                # Results are Chroma-shaped: documents[0] is the list of
                # texts for the (single) query.
                context_string += "\n\n".join(
                    [text for text in context["documents"][0] if text is not None]
                )

                if "metadatas" in context:
                    citations.append(
                        {
                            "source": context["source"],
                            "document": context["documents"][0],
                            "metadata": context["metadatas"][0],
                        }
                    )
        except Exception as e:
            log.exception(e)

    context_string = context_string.strip()

    # Render the retrieved context + query through the RAG prompt template.
    ra_content = rag_template(
        template=template,
        context=context_string,
        query=query,
    )

    log.debug(f"ra_content: {ra_content}")
    messages = add_or_update_system_message(ra_content, messages)

    return messages, citations
331

Self Denial's avatar
Self Denial committed
332

333
334
335
336
337
338
339
340
341
342
343
def get_model_path(model: str, update_model: bool = False):
    """Resolve a sentence-transformers model name or path to a local snapshot.

    Args:
        model: A filesystem path, a fully qualified repo id, or a bare
            sentence-transformers model short-name.
        update_model: When True, allow huggingface_hub to hit the network
            and refresh the snapshot; otherwise resolve offline only.

    Returns:
        A local path usable by sentence-transformers, or `model` unchanged
        when resolution fails or the input already resolves locally.
    """
    # Construct huggingface_hub kwargs; local_files_only keeps resolution
    # offline unless the caller explicitly asked to update the model.
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
    local_files_only = not update_model
    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers.
    # NOTE(review): the parentheses below spell out the precedence the
    # original expression already had (`and` binds tighter than `or`): an
    # existing path short-circuits regardless of local_files_only.
    looks_like_explicit_path = "\\" in model or model.count("/") > 1
    if os.path.exists(model) or (looks_like_explicit_path and local_files_only):
        # Fully qualified, already-resolvable input: return it untouched.
        return model
    if "/" not in model:
        # Bare short-name: qualify it under the sentence-transformers org.
        model = "sentence-transformers/" + model

    snapshot_kwargs["repo_id"] = model

    # Ask huggingface_hub for the local snapshot path (and/or update it).
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model


371
def generate_openai_embeddings(
    model: str,
    text: Union[str, list[str]],
    key: str,
    url: str = "https://api.openai.com/v1",
):
    """Embed one string or a list of strings via the OpenAI embeddings API.

    The return shape mirrors the input: a single vector for a `str`, a list
    of vectors for a `list`.
    """
    # Normalize to a batch and delegate to the batch helper.
    batch = text if isinstance(text, list) else [text]
    embeddings = generate_openai_batch_embeddings(model, batch, key, url)

    # NOTE(review): on failure the batch helper returns None, so the str
    # branch below would raise TypeError on None[0] — behavior preserved.
    return embeddings[0] if isinstance(text, str) else embeddings


def generate_openai_batch_embeddings(
    model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
    """POST one batch of texts to `{url}/embeddings` and return their vectors.

    Args:
        model: OpenAI embedding model name.
        texts: Batch of input strings.
        key: API key sent as a Bearer token.
        url: API base URL.

    Returns:
        One embedding vector per input text, or None on any failure
        (network error, non-2xx status, or a response without "data") —
        callers treat None as "embedding unavailable".
    """
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": texts, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        # Was `raise "Something went wrong :/"` — raising a string is a
        # TypeError (strings are not exceptions), so the logged error never
        # described the actual problem. Raise a real exception instead; it
        # is still caught below, preserving the None-on-failure contract.
        raise ValueError("Something went wrong :/")
    except Exception as e:
        # Log through the module logger (with traceback) instead of print().
        log.exception(e)
        return None
Steven Kreitzer's avatar
Steven Kreitzer committed
406
407
408
409
410


from typing import Any

from langchain_core.retrievers import BaseRetriever
411
from langchain_core.callbacks import CallbackManagerForRetrieverRun
Steven Kreitzer's avatar
Steven Kreitzer committed
412
413
414
415


class ChromaRetriever(BaseRetriever):
    """LangChain retriever backed by a Chroma collection vector search."""

    # Chroma collection handle to query.
    collection: Any
    # Callable mapping a query string to an embedding vector.
    embedding_function: Any
    # Number of results to request per query.
    top_n: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Embed `query` and return the top_n hits as LangChain Documents."""
        embedded_query = self.embedding_function(query)

        hits = self.collection.query(
            query_embeddings=[embedded_query],
            n_results=self.top_n,
        )

        # Chroma nests each field one level deep: one list per query.
        ids = hits["ids"][0]
        metadatas = hits["metadatas"][0]
        documents = hits["documents"][0]

        # One Document per hit; `ids` drives the count, exactly like the
        # original index loop.
        return [
            Document(
                metadata=metadatas[i],
                page_content=documents[i],
            )
            for i in range(len(ids))
        ]
445
446
447
448
449
450
451
452
453
454
455
456
457
458


import operator

from typing import Optional, Sequence

from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra

from sentence_transformers import util


class RerankCompressor(BaseDocumentCompressor):
    """Compressor that scores documents, filters by r_score, and keeps top_n."""

    # Callable mapping text (or a list of texts) to embedding(s); used only
    # when no reranking_function is provided.
    embedding_function: Any
    # Maximum number of documents to keep.
    top_n: int
    # Cross-encoder with a .predict method, or None for the cosine fallback.
    reranking_function: Any
    # Minimum score to keep a document; a falsy value disables filtering.
    r_score: float

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Score, threshold, and rank `documents` against `query`."""
        if self.reranking_function is not None:
            # Cross-encoder path: score (query, passage) pairs directly.
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            # Bi-encoder fallback: cosine similarity between embeddings.
            query_embedding = self.embedding_function(query)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents]
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        scored = list(zip(documents, scores.tolist()))
        if self.r_score:
            # NOTE(review): truthiness check — an r_score of 0.0 disables
            # the threshold entirely rather than filtering at >= 0.0.
            scored = [(doc, score) for doc, score in scored if score >= self.r_score]

        # Highest score first; sorted() is stable, matching the original
        # operator.itemgetter-based sort.
        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)

        compressed = []
        for doc, score in ranked[: self.top_n]:
            # Record the score in the hit's metadata (mutated in place, as
            # the original did) and wrap it in a fresh Document.
            metadata = doc.metadata
            metadata["score"] = score
            compressed.append(
                Document(
                    page_content=doc.page_content,
                    metadata=metadata,
                )
            )
        return compressed