utils.py 14.1 KB
Newer Older
1
import os
2
import logging
3
4
import requests

5
from typing import List, Union
6

7
8
9
10
from apps.ollama.main import (
    generate_ollama_embeddings,
    GenerateEmbeddingsForm,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
11

12
13
from huggingface_hub import snapshot_download

14
15
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
Steven Kreitzer's avatar
Steven Kreitzer committed
16
from langchain.retrievers import (
17
    ContextualCompressionRetriever,
Steven Kreitzer's avatar
Steven Kreitzer committed
18
19
20
    EnsembleRetriever,
)

21
from typing import Optional
22

Timothy J. Baek's avatar
Timothy J. Baek committed
23
from utils.misc import get_last_user_message, add_or_update_system_message
24
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
25

26
27
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29


Timothy J. Baek's avatar
Timothy J. Baek committed
30
def query_doc(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
):
    """Embed `query` and run a top-k similarity search on one Chroma collection.

    :param collection_name: name of the Chroma collection to search.
    :param query: raw query text; embedded via `embedding_function`.
    :param embedding_function: callable mapping a string to an embedding vector.
    :param k: number of nearest results to request.
    :return: raw Chroma result dict (parallel `distances`/`documents`/`metadatas` lists).
    :raises Exception: propagates any Chroma client error (e.g. unknown collection).
    """
    # NOTE: previously wrapped in `try/except Exception as e: raise e`, which is
    # a no-op; errors now propagate naturally with the same traceback.
    collection = CHROMA_CLIENT.get_collection(name=collection_name)
    query_embeddings = embedding_function(query)

    result = collection.query(
        query_embeddings=[query_embeddings],
        n_results=k,
    )

    log.info(f"query_doc:result {result}")
    return result
49
50


Timothy J. Baek's avatar
Timothy J. Baek committed
51
52
53
54
55
56
def query_doc_with_hybrid_search(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    """Hybrid (BM25 + dense vector) retrieval over one collection, then rerank.

    Builds an in-memory BM25 index over the *entire* collection, fuses it with
    dense retrieval at equal weight, reranks the fused candidates, and returns
    a Chroma-style result dict.

    :param collection_name: Chroma collection to search.
    :param query: query text.
    :param embedding_function: callable mapping text to an embedding vector.
    :param k: number of results to keep at each stage.
    :param reranking_function: cross-encoder used by RerankCompressor (may be None).
    :param r: relevance-score cutoff passed to RerankCompressor as `r_score`.
    :return: dict with `distances` (rerank scores), `documents`, `metadatas`.
    :raises Exception: propagates any retrieval error to the caller.
    """
    # NOTE: previously wrapped in `try/except Exception as e: raise e` (a no-op);
    # errors now propagate naturally.
    collection = CHROMA_CLIENT.get_collection(name=collection_name)
    documents = collection.get()  # fetch all documents to build the BM25 index

    bm25_retriever = BM25Retriever.from_texts(
        texts=documents.get("documents"),
        metadatas=documents.get("metadatas"),
    )
    bm25_retriever.k = k

    chroma_retriever = ChromaRetriever(
        collection=collection,
        embedding_function=embedding_function,
        top_n=k,
    )

    # Equal-weight fusion of lexical (BM25) and dense (Chroma) candidates.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
    )

    compressor = RerankCompressor(
        embedding_function=embedding_function,
        top_n=k,
        reranking_function=reranking_function,
        r_score=r,
    )

    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=ensemble_retriever
    )

    result = compression_retriever.invoke(query)
    # Re-shape LangChain Documents into the Chroma-style dict used elsewhere;
    # "distances" here are rerank scores stashed in metadata by RerankCompressor.
    result = {
        "distances": [[d.metadata.get("score") for d in result]],
        "documents": [[d.page_content for d in result]],
        "metadatas": [[d.metadata for d in result]],
    }

    log.info(f"query_doc_with_hybrid_search:result {result}")
    return result


Steven Kreitzer's avatar
Steven Kreitzer committed
103
def merge_and_sort_query_results(query_results, k, reverse=False):
    """Flatten several Chroma-style result dicts, order by distance, keep top k.

    :param query_results: iterable of dicts, each holding parallel
        `distances`/`documents`/`metadatas` lists wrapped in a one-element list.
    :param k: maximum number of entries to retain after sorting.
    :param reverse: False sorts ascending (raw distances: smaller is better);
        True sorts descending (rerank scores: larger is better).
    :return: one merged result dict in the same wrapped-list shape.
    """
    distances = []
    documents = []
    metadatas = []

    # Concatenate the parallel lists from every partial result.
    for partial in query_results:
        distances.extend(partial["distances"][0])
        documents.extend(partial["documents"][0])
        metadatas.extend(partial["metadatas"][0])

    # Pair each document with its distance and metadata, then order by distance.
    merged = sorted(
        zip(distances, documents, metadatas), key=lambda t: t[0], reverse=reverse
    )

    if merged:
        top = merged[:k]
        sorted_distances = [t[0] for t in top]
        sorted_documents = [t[1] for t in top]
        sorted_metadatas = [t[2] for t in top]
    else:
        # Nothing matched anywhere.
        sorted_distances = []
        sorted_documents = []
        sorted_metadatas = []

    return {
        "distances": [sorted_distances],
        "documents": [sorted_documents],
        "metadatas": [sorted_metadatas],
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
142
143


Timothy J. Baek's avatar
Timothy J. Baek committed
144
def query_collection(
    collection_names: List[str],
    query: str,
    embedding_function,
    k: int,
):
    """Query several Chroma collections and merge into one top-k result.

    Individual collections that fail are skipped so one broken collection
    cannot abort the whole query.

    :param collection_names: collections to search.
    :param query: query text passed through to `query_doc`.
    :param embedding_function: callable mapping text to an embedding vector.
    :param k: number of results to keep after merging.
    :return: merged Chroma-style result dict, sorted by ascending distance.
    """
    results = []
    for collection_name in collection_names:
        try:
            result = query_doc(
                collection_name=collection_name,
                query=query,
                k=k,
                embedding_function=embedding_function,
            )
            results.append(result)
        except Exception as e:
            # FIX: was a bare `except: pass`, which silently swallowed every
            # error (including NameError/KeyboardInterrupt). Keep the
            # best-effort behavior but log the failure and only catch Exception.
            log.exception(f"Error querying collection {collection_name}: {e}")
    return merge_and_sort_query_results(results, k=k)


def query_collection_with_hybrid_search(
    collection_names: List[str],
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    """Hybrid-search several collections and merge into one top-k result.

    Merging uses reverse=True because hybrid results carry rerank scores
    (larger is better) rather than raw distances.

    :param collection_names: collections to search; failures are skipped.
    :param query: query text.
    :param embedding_function: callable mapping text to an embedding vector.
    :param k: number of results to keep at each stage and after merging.
    :param reranking_function: cross-encoder passed to the per-collection search.
    :param r: relevance-score cutoff passed to the per-collection search.
    :return: merged Chroma-style result dict, sorted by descending score.
    """
    results = []
    for collection_name in collection_names:
        try:
            result = query_doc_with_hybrid_search(
                collection_name=collection_name,
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                r=r,
            )
            results.append(result)
        except Exception as e:
            # FIX: was a bare `except: pass`; keep best-effort semantics but
            # log the failure and catch only Exception.
            log.exception(
                f"Error querying collection {collection_name} with hybrid search: {e}"
            )
    return merge_and_sort_query_results(results, k=k, reverse=True)
188
189


Timothy J. Baek's avatar
Timothy J. Baek committed
190
def rag_template(template: str, context: str, query: str):
    """Fill a RAG prompt template.

    Substitutes the `[context]` placeholder first, then `[query]` — so a
    literal "[query]" inside the context text would also be substituted,
    matching the original substitution order.
    """
    filled = template.replace("[context]", context)
    return filled.replace("[query]", query)
Timothy J. Baek's avatar
Timothy J. Baek committed
194
195


Timothy J. Baek's avatar
Timothy J. Baek committed
196
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
    batch_size,
):
    """Build a `query -> embedding(s)` callable for the configured engine.

    :param embedding_engine: "" (local sentence-transformers model),
        "ollama", or "openai"; any other value returns None (implicitly),
        matching the original behavior.
    :param embedding_model: model name for the remote engines.
    :param embedding_function: local model exposing `.encode(...)`, used only
        when `embedding_engine` is "".
    :param openai_key: API key for the OpenAI engine.
    :param openai_url: base URL for the OpenAI engine.
    :param batch_size: chunk size for batched OpenAI requests.
    :return: callable accepting a string (single embedding) or a list of
        strings (list of embeddings).
    """
    if embedding_engine == "":
        # Local model: encode directly and convert to a plain list.
        return lambda query: embedding_function.encode(query).tolist()
    elif embedding_engine in ["ollama", "openai"]:
        if embedding_engine == "ollama":

            def embed(query):
                return generate_ollama_embeddings(
                    GenerateEmbeddingsForm(
                        **{
                            "model": embedding_model,
                            "prompt": query,
                        }
                    )
                )

        elif embedding_engine == "openai":

            def embed(query):
                return generate_openai_embeddings(
                    model=embedding_model,
                    text=query,
                    key=openai_key,
                    url=openai_url,
                )

        def generate_multiple(query, f):
            # Single string: one embedding straight through.
            if not isinstance(query, list):
                return f(query)
            if embedding_engine == "openai":
                # OpenAI accepts batched input; chunk to respect batch_size.
                embeddings = []
                for start in range(0, len(query), batch_size):
                    embeddings.extend(f(query[start : start + batch_size]))
                return embeddings
            # Ollama embeds one prompt per call.
            return [f(q) for q in query]

        return lambda query: generate_multiple(query, embed)
Steven Kreitzer's avatar
Steven Kreitzer committed
237
238


Timothy J. Baek's avatar
Timothy J. Baek committed
239
def get_rag_context(
    files,
    messages,
    embedding_function,
    k,
    reranking_function,
    r,
    hybrid_search,
):
    """Collect RAG context strings and citations for the last user message.

    For each attached file/collection, retrieves matching documents (plain,
    or hybrid when `hybrid_search` is truthy) and joins them into context
    strings, together with citation records pointing back at the source file.

    :param files: list of dicts; each is expected to carry "type" plus
        "collection_names" (type == "collection"), "collection_name"
        (other retrievable types), or "content" (type == "text").
    :param messages: chat history; only the last user message is used as query.
    :param embedding_function: callable mapping text to an embedding vector.
    :param k: number of results per retrieval.
    :param reranking_function: reranker used only on the hybrid path.
    :param r: relevance cutoff used only on the hybrid path.
    :param hybrid_search: selects hybrid vs. plain collection querying.
    :return: (contexts, citations) — list of joined document strings and a
        parallel list of {"source", "document", "metadata"} dicts.
    """
    log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
    query = get_last_user_message(messages)

    extracted_collections = []
    relevant_contexts = []

    for file in files:
        context = None

        # A "collection" file bundles several collection names; anything else
        # contributes a single collection.
        collection_names = (
            file["collection_names"]
            if file["type"] == "collection"
            else [file["collection_name"]]
        )

        # Deduplicate: skip any collection already queried for an earlier file.
        collection_names = set(collection_names).difference(extracted_collections)
        if not collection_names:
            log.debug(f"skipping {file} as it has already been extracted")
            continue

        try:
            if file["type"] == "text":
                # NOTE(review): `context` is later spread with `{**context}`,
                # so "content" is presumably a mapping (e.g. with a
                # "documents" key), not a plain string — confirm with callers.
                context = file["content"]
            else:
                if hybrid_search:
                    context = query_collection_with_hybrid_search(
                        collection_names=collection_names,
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
                        reranking_function=reranking_function,
                        r=r,
                    )
                else:
                    context = query_collection(
                        collection_names=collection_names,
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
                    )
        except Exception as e:
            # Best-effort: retrieval failure for one file just yields no context.
            log.exception(e)
            context = None

        if context:
            # Tag the merged result with its originating file for citations.
            relevant_contexts.append({**context, "source": file})

        # Mark these collections as handled even if retrieval failed, so they
        # are not retried for a later file.
        extracted_collections.extend(collection_names)

    contexts = []
    citations = []

    for context in relevant_contexts:
        try:
            if "documents" in context:
                # Join the top-ranked documents into one context blob,
                # dropping any None entries.
                contexts.append(
                    "\n\n".join(
                        [text for text in context["documents"][0] if text is not None]
                    )
                )

                if "metadatas" in context:
                    citations.append(
                        {
                            "source": context["source"],
                            "document": context["documents"][0],
                            "metadata": context["metadatas"][0],
                        }
                    )
        except Exception as e:
            # Malformed context dicts are logged and skipped, not fatal.
            log.exception(e)

    return contexts, citations
321

Self Denial's avatar
Self Denial committed
322

323
324
325
326
327
328
329
330
331
332
333
def get_model_path(model: str, update_model: bool = False):
    """Resolve a model name to a local huggingface_hub snapshot path.

    :param model: model name, repo id, or filesystem path.
    :param update_model: when True, allow huggingface_hub to download/update;
        when False, only local files are consulted.
    :return: a local snapshot path when one can be determined, otherwise the
        (possibly org-qualified) model name unchanged.
    """
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    # Only hit the network when an update is explicitly requested.
    local_files_only = not update_model

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspired by upstream sentence_transformers: a path that already exists,
    # or a path-like name (backslash, or more than one slash) while offline,
    # is returned unchanged. (Parentheses make the original or/and precedence
    # explicit.)
    looks_like_path = "\\" in model or model.count("/") > 1
    if os.path.exists(model) or (looks_like_path and local_files_only):
        return model
    elif "/" not in model:
        # Bare short name: qualify with the default sentence-transformers org.
        model = "sentence-transformers" + "/" + model

    snapshot_kwargs["repo_id"] = model

    # Ask huggingface_hub for the local snapshot path (and/or update it).
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        # Fall back to the raw name; downstream loaders may still resolve it.
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
359
360


361
def generate_openai_embeddings(
    model: str,
    text: Union[str, list[str]],
    key: str,
    url: str = "https://api.openai.com/v1",
):
    """Embed one string or a list of strings via the OpenAI embeddings API.

    Delegates to `generate_openai_batch_embeddings`, wrapping a non-list
    input into a single-element batch. A plain-string input gets the single
    embedding back; anything else gets the full batch result.
    """
    if isinstance(text, list):
        batch = text
    else:
        batch = [text]

    embeddings = generate_openai_batch_embeddings(model, batch, key, url)

    if isinstance(text, str):
        return embeddings[0]
    return embeddings


def generate_openai_batch_embeddings(
    model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
    """POST a batch of texts to the OpenAI embeddings endpoint.

    :param model: embedding model name.
    :param texts: texts to embed in a single request.
    :param key: bearer token for the Authorization header.
    :param url: API base URL (the "/embeddings" path is appended).
    :return: one embedding per input text, or None on any failure (network
        error, non-2xx status, or malformed response body).
    """
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": texts, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        else:
            # FIX: previously `raise "Something went wrong :/"` — raising a
            # string is a TypeError in Python 3. Raise a real exception; the
            # outer handler still maps it to the None error return.
            raise ValueError("Something went wrong :/")
    except Exception as e:
        # Log (instead of print) for consistency with the rest of the module,
        # and preserve the None-on-failure contract callers rely on.
        log.exception(e)
        return None
Steven Kreitzer's avatar
Steven Kreitzer committed
396
397
398
399
400


from typing import Any

from langchain_core.retrievers import BaseRetriever
401
from langchain_core.callbacks import CallbackManagerForRetrieverRun
Steven Kreitzer's avatar
Steven Kreitzer committed
402
403
404
405


class ChromaRetriever(BaseRetriever):
    """LangChain retriever backed by a raw Chroma collection.

    Embeds the query with `embedding_function` and returns the `top_n`
    nearest entries as LangChain Documents.
    """

    collection: Any
    embedding_function: Any
    top_n: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        # Embed the query and ask Chroma for the top_n nearest entries.
        embedded_query = self.embedding_function(query)

        response = self.collection.query(
            query_embeddings=[embedded_query],
            n_results=self.top_n,
        )

        ids = response["ids"][0]
        metadatas = response["metadatas"][0]
        documents = response["documents"][0]

        # One Document per returned id, pairing each text with its metadata.
        return [
            Document(
                metadata=metadatas[idx],
                page_content=documents[idx],
            )
            for idx in range(len(ids))
        ]
435
436
437
438
439
440
441
442
443
444
445
446


import operator

from typing import Optional, Sequence

from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra


class RerankCompressor(BaseDocumentCompressor):
    """Compressor that rescores retrieved documents against the query.

    Uses a cross-encoder (`reranking_function`) when one is configured,
    otherwise falls back to cosine similarity between query and document
    embeddings. Keeps at most `top_n` documents with score >= `r_score`
    (a falsy `r_score`, e.g. 0, disables the threshold), and stores each
    score into the document's metadata under "score".
    """

    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        if self.reranking_function is not None:
            # Cross-encoder path: score each (query, document) pair directly.
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            # Fallback: cosine similarity between query and document embeddings.
            from sentence_transformers import util

            query_embedding = self.embedding_function(query)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents]
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        scored = list(zip(documents, scores.tolist()))
        # A zero/None threshold is falsy and disables filtering entirely.
        if self.r_score:
            scored = [(doc, score) for doc, score in scored if score >= self.r_score]

        # Best score first, then keep the top_n.
        ranked = sorted(scored, key=operator.itemgetter(1), reverse=True)

        compressed = []
        for doc, score in ranked[: self.top_n]:
            # Record the score in the (shared) metadata dict, as before.
            metadata = doc.metadata
            metadata["score"] = score
            compressed.append(
                Document(
                    page_content=doc.page_content,
                    metadata=metadata,
                )
            )
        return compressed