utils.py 15.5 KB
Newer Older
1
import os
2
import logging
3
4
import requests

5
from typing import List, Union
6

7
8
9
10
from apps.ollama.main import (
    generate_ollama_embeddings,
    GenerateEmbeddingsForm,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
11

12
13
from huggingface_hub import snapshot_download

14
15
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
Steven Kreitzer's avatar
Steven Kreitzer committed
16
from langchain.retrievers import (
17
    ContextualCompressionRetriever,
Steven Kreitzer's avatar
Steven Kreitzer committed
18
19
20
    EnsembleRetriever,
)

21
from typing import Optional
22

Timothy J. Baek's avatar
Timothy J. Baek committed
23

24
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
25

26
27
# Module-level logger, tuned by the project's per-subsystem log-level config
# (the "RAG" key of SRC_LOG_LEVELS).
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29


Timothy J. Baek's avatar
Timothy J. Baek committed
30
def query_doc(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
):
    """Run a dense similarity query against a single Chroma collection.

    Args:
        collection_name: Name of the Chroma collection to search.
        query: Raw query string; embedded via ``embedding_function``.
        embedding_function: Callable mapping a string to an embedding vector.
        k: Number of nearest results to request.

    Returns:
        The raw Chroma query result dict (documents/metadatas/distances lists).

    Raises:
        Exception: Whatever ``CHROMA_CLIENT`` or the collection query raises
            (e.g. unknown collection); logged here, then re-raised.
    """
    try:
        collection = CHROMA_CLIENT.get_collection(name=collection_name)
        query_embeddings = embedding_function(query)

        result = collection.query(
            query_embeddings=[query_embeddings],
            n_results=k,
        )

        log.info(f"query_doc:result {result}")
        return result
    except Exception as e:
        # The original `raise e` was a no-op wrapper; log the failure and
        # re-raise with the original traceback intact.
        log.exception(e)
        raise
49
50


Timothy J. Baek's avatar
Timothy J. Baek committed
51
52
53
54
55
56
def query_doc_with_hybrid_search(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    """Hybrid (BM25 + dense) search over one collection, with reranking.

    Builds a 50/50 ensemble of a BM25 retriever (over all texts in the
    collection) and a dense ``ChromaRetriever``, then compresses the ensemble
    output through ``RerankCompressor`` (cross-encoder scores when
    ``reranking_function`` is provided, otherwise embedding cosine similarity).

    Args:
        collection_name: Name of the Chroma collection to search.
        query: Raw query string.
        embedding_function: Callable mapping text to an embedding vector.
        k: Number of results to keep at each stage.
        reranking_function: Optional cross-encoder with a ``predict`` method.
        r: Minimum relevance score; results scoring below it are dropped.

    Returns:
        A Chroma-shaped dict with "distances" (rerank scores), "documents",
        and "metadatas", each wrapped in a single-row list.

    Raises:
        Exception: Any failure from Chroma or the retriever stack; logged,
            then re-raised.
    """
    try:
        collection = CHROMA_CLIENT.get_collection(name=collection_name)
        documents = collection.get()  # get all documents

        bm25_retriever = BM25Retriever.from_texts(
            texts=documents.get("documents"),
            metadatas=documents.get("metadatas"),
        )
        bm25_retriever.k = k

        chroma_retriever = ChromaRetriever(
            collection=collection,
            embedding_function=embedding_function,
            top_n=k,
        )

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
        )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k,
            reranking_function=reranking_function,
            r_score=r,
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = compression_retriever.invoke(query)
        # Re-shape the Document list into the same dict layout query_doc
        # returns, so callers can merge both result kinds uniformly.
        result = {
            "distances": [[d.metadata.get("score") for d in result]],
            "documents": [[d.page_content for d in result]],
            "metadatas": [[d.metadata for d in result]],
        }

        log.info(f"query_doc_with_hybrid_search:result {result}")
        return result
    except Exception as e:
        # The original `raise e` was a no-op wrapper; log before re-raising.
        log.exception(e)
        raise


Steven Kreitzer's avatar
Steven Kreitzer committed
103
def merge_and_sort_query_results(query_results, k, reverse=False):
    """Merge per-collection Chroma-style results into one ranked result.

    Each entry of ``query_results`` is a dict with single-row "distances",
    "documents", and "metadatas" lists. The rows are concatenated, sorted by
    distance (ascending by default; ``reverse=True`` for score-like values
    where higher is better), and truncated to the top ``k`` entries.

    Returns a dict in the same single-row layout.
    """
    # Flatten the first (only) row of every result into parallel lists.
    all_distances = []
    all_documents = []
    all_metadatas = []
    for entry in query_results:
        all_distances.extend(entry["distances"][0])
        all_documents.extend(entry["documents"][0])
        all_metadatas.extend(entry["metadatas"][0])

    # Rank triples by their distance/score, then keep only the top k.
    ranked = sorted(
        zip(all_distances, all_documents, all_metadatas),
        key=lambda triple: triple[0],
        reverse=reverse,
    )[:k]

    if ranked:
        top_distances, top_documents, top_metadatas = (
            list(column) for column in zip(*ranked)
        )
    else:
        # Nothing matched anywhere — keep the row structure with empty lists.
        top_distances, top_documents, top_metadatas = [], [], []

    return {
        "distances": [top_distances],
        "documents": [top_documents],
        "metadatas": [top_metadatas],
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
142
143


Timothy J. Baek's avatar
Timothy J. Baek committed
144
def query_collection(
    collection_names: List[str],
    query: str,
    embedding_function,
    k: int,
):
    """Query several collections and merge the results into one ranked set.

    Collections that fail to query (e.g. missing) are skipped so one bad
    collection does not abort the whole search; failures are logged.

    Args:
        collection_names: Chroma collection names to query.
        query: Raw query string.
        embedding_function: Callable mapping a string to an embedding vector.
        k: Number of results per collection and in the merged output.

    Returns:
        Merged result dict from ``merge_and_sort_query_results`` (ascending
        distance order).
    """
    results = []
    for collection_name in collection_names:
        try:
            result = query_doc(
                collection_name=collection_name,
                query=query,
                k=k,
                embedding_function=embedding_function,
            )
            results.append(result)
        except Exception as e:
            # The original bare `except: pass` also swallowed
            # KeyboardInterrupt/SystemExit and hid real errors; keep the
            # best-effort behavior but narrow the catch and log it.
            log.exception(e)
    return merge_and_sort_query_results(results, k=k)


def query_collection_with_hybrid_search(
    collection_names: List[str],
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    """Hybrid-search several collections and merge the reranked results.

    Per-collection failures are skipped (but logged) so one bad collection
    does not abort the whole search.

    Args:
        collection_names: Chroma collection names to query.
        query: Raw query string.
        embedding_function: Callable mapping text to an embedding vector.
        k: Number of results per collection and in the merged output.
        reranking_function: Optional cross-encoder used by the compressor.
        r: Minimum relevance score threshold.

    Returns:
        Merged result dict; ``reverse=True`` because hybrid results carry
        relevance scores (higher is better) rather than distances.
    """
    results = []
    for collection_name in collection_names:
        try:
            result = query_doc_with_hybrid_search(
                collection_name=collection_name,
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                r=r,
            )
            results.append(result)
        except Exception as e:
            # The original bare `except: pass` also swallowed
            # KeyboardInterrupt/SystemExit and hid real errors; keep the
            # best-effort behavior but narrow the catch and log it.
            log.exception(e)
    return merge_and_sort_query_results(results, k=k, reverse=True)
188
189


Timothy J. Baek's avatar
Timothy J. Baek committed
190
def rag_template(template: str, context: str, query: str):
    """Insert the retrieved context and the user query into the RAG template.

    Both placeholders are substituted in a single pass so that a literal
    "[query]" occurring inside the retrieved context (or "[context]" inside
    the user query) is never re-expanded — the original sequential
    ``str.replace`` calls replaced "[context]" first and would then rewrite
    any "[query]" the context happened to contain.

    Args:
        template: Prompt template containing "[context]" / "[query]" markers.
        context: Retrieved context text.
        query: The user's query text.

    Returns:
        The template with both placeholders filled in.
    """
    import re  # local import keeps this fix self-contained

    substitutions = {"[context]": context, "[query]": query}
    return re.sub(
        r"\[context\]|\[query\]",
        lambda match: substitutions[match.group(0)],
        template,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
194
195


Timothy J. Baek's avatar
Timothy J. Baek committed
196
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
    batch_size,
):
    """Build a callable that embeds a single string or a list of strings.

    For the default engine ("") the local model's ``encode`` is used
    directly. For "ollama"/"openai" the returned callable dispatches to the
    respective API: list inputs are embedded one prompt per call (ollama) or
    in ``batch_size`` chunks (openai). Any other engine falls through and
    returns None.
    """
    if embedding_engine == "":
        # Local sentence-transformers model: encode handles str or list.
        return lambda query: embedding_function.encode(query).tolist()

    if embedding_engine in ["ollama", "openai"]:
        if embedding_engine == "ollama":

            def embed_single(query):
                return generate_ollama_embeddings(
                    GenerateEmbeddingsForm(
                        **{
                            "model": embedding_model,
                            "prompt": query,
                        }
                    )
                )

        else:

            def embed_single(query):
                return generate_openai_embeddings(
                    model=embedding_model,
                    text=query,
                    key=openai_key,
                    url=openai_url,
                )

        def embed(query):
            if not isinstance(query, list):
                return embed_single(query)
            if embedding_engine == "openai":
                # The OpenAI endpoint accepts whole batches per request.
                vectors = []
                for start in range(0, len(query), batch_size):
                    vectors.extend(embed_single(query[start : start + batch_size]))
                return vectors
            # Ollama embeds one prompt per call.
            return [embed_single(item) for item in query]

        return embed
Steven Kreitzer's avatar
Steven Kreitzer committed
237
238


239
240
241
242
def rag_messages(
    docs,
    messages,
    template,
    embedding_function,
    k,
    reranking_function,
    r,
    hybrid_search,
):
    """Inject retrieved context into the last user message of a chat.

    For each doc spec in ``docs``, retrieves relevant context (hybrid or
    dense search, or inline text for "text" docs), renders it into
    ``template`` together with the user's query, and replaces the content of
    the most recent "user" message with the rendered prompt. Mutates
    ``messages`` in place and also returns it, together with a list of
    citation dicts (source/document/metadata) for the UI.
    """
    log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")

    # Find the most recent message with role "user"; that is the one whose
    # content gets rewritten with the RAG prompt.
    # NOTE(review): if no "user" message exists, last_user_message_idx stays
    # None and messages[None] below raises TypeError — confirm callers
    # always pass at least one user message.
    last_user_message_idx = None
    for i in range(len(messages) - 1, -1, -1):
        if messages[i]["role"] == "user":
            last_user_message_idx = i
            break

    user_message = messages[last_user_message_idx]

    # Extract the query text; content may be a plain string or a list of
    # typed items (e.g. mixed text/image parts).
    if isinstance(user_message["content"], list):
        # Handle list content input: use the first "text" item as the query.
        content_type = "list"
        query = ""
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                query = content_item["text"]
                break
    elif isinstance(user_message["content"], str):
        # Handle text content input
        content_type = "text"
        query = user_message["content"]
    else:
        # Fallback in case the input does not match expected types
        content_type = None
        query = ""

    # Track which collections have already been queried so duplicate doc
    # specs don't trigger redundant searches.
    extracted_collections = []
    relevant_contexts = []

    for doc in docs:
        context = None

        # A "collection" doc carries several collection names; every other
        # doc type carries exactly one.
        collection_names = (
            doc["collection_names"]
            if doc["type"] == "collection"
            else [doc["collection_name"]]
        )

        collection_names = set(collection_names).difference(extracted_collections)
        if not collection_names:
            log.debug(f"skipping {doc} as it has already been extracted")
            continue

        try:
            if doc["type"] == "text":
                # Inline text docs need no retrieval at all.
                context = doc["content"]
            else:
                if hybrid_search:
                    context = query_collection_with_hybrid_search(
                        collection_names=collection_names,
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
                        reranking_function=reranking_function,
                        r=r,
                    )
                else:
                    context = query_collection(
                        collection_names=collection_names,
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
                    )
        except Exception as e:
            # Retrieval failures are logged and treated as "no context".
            log.exception(e)
            context = None

        if context:
            # Keep the originating doc spec alongside the result for
            # citation rendering below.
            relevant_contexts.append({**context, "source": doc})

        extracted_collections.extend(collection_names)

    context_string = ""

    citations = []
    for context in relevant_contexts:
        try:
            if "documents" in context:
                # Only the first result row is used; None entries (possible
                # in merged results) are filtered out.
                context_string += "\n\n".join(
                    [text for text in context["documents"][0] if text is not None]
                )

                if "metadatas" in context:
                    citations.append(
                        {
                            "source": context["source"],
                            "document": context["documents"][0],
                            "metadata": context["metadatas"][0],
                        }
                    )
        except Exception as e:
            log.exception(e)

    context_string = context_string.strip()

    # Render the final prompt from the template, context, and query.
    ra_content = rag_template(
        template=template,
        context=context_string,
        query=query,
    )

    log.debug(f"ra_content: {ra_content}")

    # Write the rendered prompt back into the user message, preserving the
    # original content shape (list of typed items vs. plain string).
    if content_type == "list":
        new_content = []
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                # Update the text item's content with ra_content
                new_content.append({"type": "text", "text": ra_content})
            else:
                # Keep other types of content as they are
                new_content.append(content_item)
        new_user_message = {**user_message, "content": new_content}
    else:
        new_user_message = {
            **user_message,
            "content": ra_content,
        }

    messages[last_user_message_idx] = new_user_message

    return messages, citations
372

Self Denial's avatar
Self Denial committed
373

374
375
376
377
378
379
380
381
382
383
384
def get_model_path(model: str, update_model: bool = False) -> str:
    """Resolve a model name to a local huggingface snapshot path.

    Args:
        model: Model name, hub repo id, or filesystem path.
        update_model: When True, allow huggingface_hub to download/refresh
            the snapshot instead of restricting to local files.

    Returns:
        The local snapshot directory path, or ``model`` unchanged when it is
        treated as a path or when snapshot resolution fails.
    """
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    local_files_only = not update_model

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    # NOTE(review): Python precedence parses this as
    # `exists(model) or ((has-backslash or >1 slash) and local_files_only)`,
    # so an existing path is returned even when update_model is True —
    # confirm that is the intended behavior.
    if (
        os.path.exists(model)
        or ("\\" in model or model.count("/") > 1)
        and local_files_only
    ):
        # If fully qualified path exists, return input, else set repo_id
        return model
    elif "/" not in model:
        # Set valid repo_id for model short-name
        model = "sentence-transformers" + "/" + model

    snapshot_kwargs["repo_id"] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        # Fall back to the raw model string so callers can still try it.
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model


412
def generate_openai_embeddings(
    model: str,
    text: Union[str, list[str]],
    key: str,
    url: str = "https://api.openai.com/v1",
):
    """Embed a string or a list of strings via the OpenAI embeddings API.

    Args:
        model: Embedding model name.
        text: A single string or a list of strings to embed.
        key: API key sent as a Bearer token.
        url: Base URL of the OpenAI-compatible API.

    Returns:
        A single embedding (list of floats) for a string input, a list of
        embeddings for a list input, or None when the batch request failed
        (``generate_openai_batch_embeddings`` returns None on error).
    """
    texts = text if isinstance(text, list) else [text]
    embeddings = generate_openai_batch_embeddings(model, texts, key, url)

    # Guard the failure path: the batch helper returns None on error, and
    # indexing None with [0] would raise TypeError and mask the real cause.
    if embeddings is None:
        return None

    return embeddings if isinstance(text, list) else embeddings[0]


def generate_openai_batch_embeddings(
    model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
    """POST one batch of texts to an OpenAI-compatible /embeddings endpoint.

    Args:
        model: Embedding model name.
        texts: Batch of input strings, embedded in one request.
        key: API key sent as a Bearer token.
        url: Base URL of the OpenAI-compatible API.

    Returns:
        The embedding vectors in input order, or None on any failure
        (network error, non-2xx status, malformed response body).
    """
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": texts, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        # The original `raise "Something went wrong :/"` raised a str, which
        # is itself a TypeError in Python 3; raise a real exception instead.
        raise ValueError(f"Unexpected embeddings response: {data}")
    except Exception as e:
        # Use the module logger (with traceback) instead of print().
        log.exception(e)
        return None
Steven Kreitzer's avatar
Steven Kreitzer committed
447
448
449
450
451


from typing import Any

from langchain_core.retrievers import BaseRetriever
452
from langchain_core.callbacks import CallbackManagerForRetrieverRun
Steven Kreitzer's avatar
Steven Kreitzer committed
453
454
455
456


class ChromaRetriever(BaseRetriever):
    """LangChain retriever backed by a single Chroma collection.

    Embeds the query with ``embedding_function`` and returns the ``top_n``
    nearest documents as LangChain ``Document`` objects.
    """

    collection: Any
    embedding_function: Any
    top_n: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        # Dense similarity search: embed the query, then ask Chroma for the
        # top_n nearest entries.
        embedded_query = self.embedding_function(query)

        query_result = self.collection.query(
            query_embeddings=[embedded_query],
            n_results=self.top_n,
        )

        # Only the first (and only) result row is relevant here; the three
        # lists are parallel by index.
        ids = query_result["ids"][0]
        metadatas = query_result["metadatas"][0]
        documents = query_result["documents"][0]

        return [
            Document(
                metadata=metadatas[position],
                page_content=documents[position],
            )
            for position in range(len(ids))
        ]
486
487
488
489
490
491
492
493
494
495
496
497
498
499


import operator

from typing import Optional, Sequence

from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra

from sentence_transformers import util


class RerankCompressor(BaseDocumentCompressor):
    """Document compressor that re-scores, filters, and truncates candidates.

    Scores each document against the query with a cross-encoder when
    ``reranking_function`` is set, otherwise with cosine similarity between
    embeddings. Documents scoring below ``r_score`` (when non-zero) are
    dropped, and the ``top_n`` highest-scoring survivors are returned with
    their score stored under the "score" metadata key.
    """

    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        if self.reranking_function is not None:
            # Cross-encoder style scoring on (query, passage) pairs.
            pairs = [(query, document.page_content) for document in documents]
            scores = self.reranking_function.predict(pairs)
        else:
            # Fall back to cosine similarity between embeddings.
            query_embedding = self.embedding_function(query)
            passage_embeddings = self.embedding_function(
                [document.page_content for document in documents]
            )
            scores = util.cos_sim(query_embedding, passage_embeddings)[0]

        scored = list(zip(documents, scores.tolist()))
        if self.r_score:
            # Threshold filter; a falsy r_score (0.0) disables it.
            scored = [pair for pair in scored if pair[1] >= self.r_score]

        # Highest score first, then keep only the top_n entries.
        scored.sort(key=operator.itemgetter(1), reverse=True)

        compressed = []
        for document, score in scored[: self.top_n]:
            metadata = document.metadata
            metadata["score"] = score
            compressed.append(
                Document(
                    page_content=document.page_content,
                    metadata=metadata,
                )
            )
        return compressed