"src/vscode:/vscode.git/clone" did not exist on "58711bbc2da5db8126c775f2163e3562098b35ab"
utils.py 14.1 KB
Newer Older
1
import os
2
import logging
3
4
import requests

5
from typing import List
6

7
8
9
10
from apps.ollama.main import (
    generate_ollama_embeddings,
    GenerateEmbeddingsForm,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
11

12
13
from huggingface_hub import snapshot_download

14
15
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
Steven Kreitzer's avatar
Steven Kreitzer committed
16
from langchain.retrievers import (
17
    ContextualCompressionRetriever,
Steven Kreitzer's avatar
Steven Kreitzer committed
18
19
20
    EnsembleRetriever,
)

21
from typing import Optional
22
23
from config import SRC_LOG_LEVELS, CHROMA_CLIENT

24

25
26
# Module-level logger for the RAG helpers; its verbosity is taken from the
# "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["RAG"])
Timothy J. Baek's avatar
Timothy J. Baek committed
27
28


Steven Kreitzer's avatar
Steven Kreitzer committed
29
30
31
32
def query_embeddings_doc(
    collection_name: str,
    query: str,
    embeddings_function,
    reranking_function,
    k: int,
    r: float,
    hybrid_search: bool,
):
    """Return the chunks of one Chroma collection that best match ``query``.

    Args:
        collection_name: Name of the Chroma collection to search.
        query: The user query text.
        embeddings_function: Callable that turns the query text into an
            embedding vector (also handed to the hybrid retrievers).
        reranking_function: Optional reranker forwarded to
            ``RerankCompressor``; only used when ``hybrid_search`` is true.
        k: Maximum number of chunks to return.
        r: Minimum relevance score a chunk must reach in hybrid mode
            (a falsy value disables the threshold).
        hybrid_search: When true, combine BM25 and vector retrieval and
            rerank the union; otherwise run a plain vector query.

    Returns:
        A Chroma-style dict with ``"distances"``, ``"documents"`` and
        ``"metadatas"`` keys, each holding one list of results. In hybrid
        mode ``"distances"`` carries the ``"score"`` computed by
        ``RerankCompressor``; in vector mode it carries Chroma's distances.

    Raises:
        Whatever Chroma, the retrievers or the embedding/reranking functions
        raise is propagated unchanged to the caller.
    """
    collection = CHROMA_CLIENT.get_collection(name=collection_name)

    if hybrid_search:
        documents = collection.get()  # get all documents
        texts = documents.get("documents") or []

        if not texts:
            # Don't build a BM25 index over an empty corpus; an empty
            # collection has nothing to return anyway.
            result = {"distances": [[]], "documents": [[]], "metadatas": [[]]}
        else:
            bm25_retriever = BM25Retriever.from_texts(
                texts=texts,
                metadatas=documents.get("metadatas"),
            )
            bm25_retriever.k = k

            chroma_retriever = ChromaRetriever(
                collection=collection,
                embeddings_function=embeddings_function,
                top_n=k,
            )

            # Keyword and vector candidates are pooled with equal weight...
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
            )

            # ...then rescored, thresholded by ``r`` and cut to ``k``.
            compressor = RerankCompressor(
                embeddings_function=embeddings_function,
                reranking_function=reranking_function,
                r_score=r,
                top_n=k,
            )

            compression_retriever = ContextualCompressionRetriever(
                base_compressor=compressor, base_retriever=ensemble_retriever
            )

            docs = compression_retriever.invoke(query)
            result = {
                "distances": [[d.metadata.get("score") for d in docs]],
                "documents": [[d.page_content for d in docs]],
                "metadatas": [[d.metadata for d in docs]],
            }
    else:
        query_embeddings = embeddings_function(query)
        result = collection.query(
            query_embeddings=[query_embeddings],
            n_results=k,
        )

    log.debug(f"query_embeddings_doc:result {result}")
    return result


Steven Kreitzer's avatar
Steven Kreitzer committed
89
def merge_and_sort_query_results(query_results, k, reverse=False):
    """Pool several Chroma-style query results and keep the best ``k`` rows.

    Each entry of ``query_results`` is a dict whose ``"distances"``,
    ``"documents"`` and ``"metadatas"`` values are lists wrapping a single
    list of results. All rows are pooled, ordered by their ``"distances"``
    value (ascending, or descending when ``reverse`` is true; ties keep
    their input order) and truncated to the first ``k``.

    Returns:
        A dict of the same shape holding the merged rows (empty lists when
        there is nothing to merge).
    """
    query_results = list(query_results)

    def _pooled(field):
        # Concatenate the single result list of ``field`` from every input.
        return list(
            itertools.chain.from_iterable(data[field][0] for data in query_results)
        )

    rows = list(zip(_pooled("distances"), _pooled("documents"), _pooled("metadatas")))
    rows.sort(key=lambda row: row[0], reverse=reverse)
    top = rows[:k]

    return {
        "distances": [[row[0] for row in top]],
        "documents": [[row[1] for row in top]],
        "metadatas": [[row[2] for row in top]],
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
128
129


130
def query_embeddings_collection(
    collection_names: List[str],
    query: str,
    k: int,
    r: float,
    embeddings_function,
    reranking_function,
    hybrid_search: bool,
):
    """Query several Chroma collections and merge the best ``k`` chunks.

    Each collection is queried with ``query_embeddings_doc``. A collection
    that fails (e.g. it no longer exists) is logged and skipped so the
    others can still contribute context.

    Args:
        collection_names: Names of the collections to search.
        query: The user query text.
        k: Maximum number of chunks to return overall.
        r: Relevance threshold forwarded to hybrid search.
        embeddings_function: Callable embedding the query text.
        reranking_function: Optional reranker for hybrid search.
        hybrid_search: Whether to use BM25 + vector search with reranking.

    Returns:
        A Chroma-style result dict as produced by
        ``merge_and_sort_query_results``.
    """
    results = []

    for collection_name in collection_names:
        try:
            result = query_embeddings_doc(
                collection_name=collection_name,
                query=query,
                k=k,
                r=r,
                embeddings_function=embeddings_function,
                reranking_function=reranking_function,
                hybrid_search=hybrid_search,
            )
        except Exception as e:
            # Best effort: keep going with the remaining collections, but
            # leave a trace instead of silently dropping the failure.
            log.exception(f"Failed to query collection {collection_name}: {e}")
            continue
        results.append(result)

    # In hybrid mode the "distances" are RerankCompressor scores
    # (cross-encoder output or cosine similarity), where higher is more
    # relevant, so rank descending; plain vector search returns Chroma
    # distances, where lower is better. bool() because list.sort only
    # accepts an int/bool for ``reverse``.
    return merge_and_sort_query_results(results, k=k, reverse=bool(hybrid_search))
159
160


Timothy J. Baek's avatar
Timothy J. Baek committed
161
def rag_template(template: str, context: str, query: str):
    """Fill the ``[context]`` and ``[query]`` placeholders of a RAG prompt.

    Both placeholders are substituted in a single pass over ``template``, so
    text inserted for one placeholder is never rescanned. A retrieved
    document that happens to contain the literal ``[query]`` therefore stays
    intact instead of having the user's query spliced into it. Replacement
    values are inserted literally (no backslash/group processing).

    Args:
        template: Prompt template containing the placeholders.
        context: Retrieved context text.
        query: The user query text.

    Returns:
        The filled-in prompt.
    """
    import re

    values = {"context": context, "query": query}
    return re.sub(r"\[(context|query)\]", lambda m: values[m.group(1)], template)
Timothy J. Baek's avatar
Timothy J. Baek committed
165
166


Timothy J. Baek's avatar
Timothy J. Baek committed
167
def get_embeddings_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
):
    """Build the callable used to embed text for the configured engine.

    - ``""``: delegate to ``embedding_function.encode`` and convert the
      result with ``.tolist()``.
    - ``"ollama"`` / ``"openai"``: call the remote embedding helper once
      per string; a list input yields a list of embeddings.
    - anything else: no embedder is available and ``None`` is returned.
    """
    if embedding_engine == "":

        def embed_locally(query):
            return embedding_function.encode(query).tolist()

        return embed_locally

    if embedding_engine == "ollama":

        def embed_one(text):
            form = GenerateEmbeddingsForm(model=embedding_model, prompt=text)
            return generate_ollama_embeddings(form)

    elif embedding_engine == "openai":

        def embed_one(text):
            return generate_openai_embeddings(
                model=embedding_model,
                text=text,
                key=openai_key,
                url=openai_url,
            )

    else:
        return None

    def embed(query):
        # Remote engines embed a single string per request, so map lists.
        if isinstance(query, list):
            return [embed_one(item) for item in query]
        return embed_one(query)

    return embed
Steven Kreitzer's avatar
Steven Kreitzer committed
201
202


203
204
205
206
207
def _last_user_message_index(messages):
    """Return the index of the last message whose role is "user", or None."""
    for idx, message in reversed(list(enumerate(messages))):
        if message["role"] == "user":
            return idx
    return None


def _extract_user_query(content):
    """Split user-message content into ``(content_type, query)``.

    ``content_type`` is "list" for multi-part content, "text" for a plain
    string and None otherwise. For list content the query is the first part
    whose type is "text" ("" if there is none).
    """
    if isinstance(content, list):
        query = next((item["text"] for item in content if item["type"] == "text"), "")
        return "list", query
    if isinstance(content, str):
        return "text", content
    return None, ""


def _context_to_text(context):
    """Flatten one retrieved context into prompt text.

    An inline "text" doc contributes its raw content string; a query result
    contributes its first documents list joined by blank lines.
    """
    if isinstance(context, str):
        return context
    return "\n\n".join(context["documents"][0])


def rag_messages(
    docs,
    messages,
    template,
    k,
    r,
    hybrid_search,
    embedding_engine,
    embedding_model,
    embedding_function,
    reranking_function,
    openai_key,
    openai_url,
):
    """Augment the last user message with context retrieved from ``docs``.

    Args:
        docs: Context sources. Each is a dict whose "type" is "text" (inline
            "content"), "collection" (searched via "collection_names") or
            anything else (searched via "collection_name").
        messages: Chat messages (dicts with "role" and "content").
        template: RAG prompt template with ``[context]``/``[query]``
            placeholders.
        k, r, hybrid_search: Retrieval size, relevance threshold and mode,
            forwarded to the query helpers.
        embedding_engine, embedding_model, embedding_function, openai_key,
        openai_url: Used to build the query embedder via
            ``get_embeddings_function``.
        reranking_function: Optional reranker for hybrid search.

    Returns:
        ``messages``, modified in place: the last user message's text is
        replaced by the filled template. If there is no user message the
        list is returned unchanged.
    """
    # openai_key is deliberately left out: credentials must not reach logs.
    log.debug(
        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_url}"
    )

    last_user_message_idx = _last_user_message_index(messages)
    if last_user_message_idx is None:
        log.debug("rag_messages: no user message found, nothing to augment")
        return messages

    user_message = messages[last_user_message_idx]
    content_type, query = _extract_user_query(user_message["content"])

    embeddings_function = get_embeddings_function(
        embedding_engine,
        embedding_model,
        embedding_function,
        openai_key,
        openai_url,
    )

    extracted_collections = set()
    relevant_contexts = []

    for doc in docs:
        context = None
        doc_collections = set()

        if doc.get("type") != "text":
            # Skip retrieval docs whose collections were all extracted
            # already. Inline text docs carry their own content and have no
            # collection, so they are never skipped here.
            names = (
                [doc["collection_name"]]
                if doc.get("collection_name")
                else doc.get("collection_names", [])
            )
            doc_collections = set(names).difference(extracted_collections)
            if not doc_collections:
                log.debug(f"skipping {doc} as it has already been extracted")
                continue

        try:
            if doc["type"] == "text":
                context = doc["content"]
            elif doc["type"] == "collection":
                context = query_embeddings_collection(
                    collection_names=doc["collection_names"],
                    query=query,
                    k=k,
                    r=r,
                    embeddings_function=embeddings_function,
                    reranking_function=reranking_function,
                    hybrid_search=hybrid_search,
                )
            else:
                context = query_embeddings_doc(
                    collection_name=doc["collection_name"],
                    query=query,
                    k=k,
                    r=r,
                    embeddings_function=embeddings_function,
                    reranking_function=reranking_function,
                    hybrid_search=hybrid_search,
                )
        except Exception as e:
            log.exception(e)
            context = None

        if context:
            relevant_contexts.append(context)

        extracted_collections.update(doc_collections)

    # Separate contexts from different docs with a blank line; empty ones
    # are dropped so they do not leave stray separators.
    context_string = "\n\n".join(
        text for text in map(_context_to_text, relevant_contexts) if text
    ).strip()

    ra_content = rag_template(
        template=template,
        context=context_string,
        query=query,
    )

    log.debug(f"ra_content: {ra_content}")

    if content_type == "list":
        # Every text part becomes the augmented prompt; other parts (e.g.
        # images) are kept as they are.
        new_content = [
            {"type": "text", "text": ra_content} if item["type"] == "text" else item
            for item in user_message["content"]
        ]
        new_user_message = {**user_message, "content": new_content}
    else:
        new_user_message = {**user_message, "content": ra_content}

    messages[last_user_message_idx] = new_user_message

    return messages
336

Self Denial's avatar
Self Denial committed
337

338
339
340
341
342
343
344
345
346
347
348
def get_model_path(model: str, update_model: bool = False):
    """Resolve ``model`` to a local path for sentence-transformers.

    ``model`` is returned as-is when it exists on disk, or when it looks like
    a filesystem path (contains a backslash or more than one "/") and
    ``update_model`` is false. Otherwise a bare short name is prefixed with
    "sentence-transformers/", and huggingface_hub's ``snapshot_download`` is
    asked for the snapshot path, downloading only when ``update_model`` is
    true. On any failure the (possibly prefixed) repo id is returned.
    """
    local_only = not update_model
    download_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": local_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {download_kwargs}")

    # Mirrors the path heuristic used by upstream sentence_transformers.
    looks_like_path = "\\" in model or model.count("/") > 1
    if os.path.exists(model) or (looks_like_path and local_only):
        return model

    repo_id = model if "/" in model else f"sentence-transformers/{model}"
    download_kwargs["repo_id"] = repo_id

    try:
        repo_path = snapshot_download(**download_kwargs)
        log.debug(f"model_repo_path: {repo_path}")
        return repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return repo_id


376
def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
    """Request an embedding for ``text`` from an OpenAI-compatible API.

    Args:
        model: Embedding model name sent in the request body.
        text: Text to embed (sent as ``input``).
        key: API key used as a Bearer token.
        url: API base URL; a trailing "/" is tolerated.

    Returns:
        The ``embedding`` of the first entry in the response's ``data`` list,
        or None if the request fails or the response has no ``data`` field.
        Failures are logged, not raised.
    """
    try:
        r = requests.post(
            f"{url.rstrip('/')}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" not in data:
            # Raising a plain str is itself a TypeError; raise a real exception.
            raise ValueError("Something went wrong :/")
        return data["data"][0]["embedding"]
    except Exception as e:
        log.exception(e)
        return None
Steven Kreitzer's avatar
Steven Kreitzer committed
397
398
399
400
401


from typing import Any

from langchain_core.retrievers import BaseRetriever
402
from langchain_core.callbacks import CallbackManagerForRetrieverRun
Steven Kreitzer's avatar
Steven Kreitzer committed
403
404
405
406
407


class ChromaRetriever(BaseRetriever):
    """LangChain retriever returning the ``top_n`` nearest chunks of a Chroma collection."""

    collection: Any  # Chroma collection queried via .query(...)
    embeddings_function: Any  # callable: query text -> embedding vector
    top_n: int  # number of nearest results to request

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Embed ``query``, search the collection and wrap every hit as a Document."""
        response = self.collection.query(
            query_embeddings=[self.embeddings_function(query)],
            n_results=self.top_n,
        )

        hit_ids = response["ids"][0]
        hit_metadatas = response["metadatas"][0]
        hit_texts = response["documents"][0]

        hits = []
        for position, _hit_id in enumerate(hit_ids):
            hits.append(
                Document(
                    metadata=hit_metadatas[position],
                    page_content=hit_texts[position],
                )
            )
        return hits
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479


import operator

from typing import Optional, Sequence

from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra

from sentence_transformers import util


class RerankCompressor(BaseDocumentCompressor):
    """Rescore retrieved documents against the query and keep the best ``top_n``.

    Scores come from ``reranking_function.predict`` over (query, text) pairs
    when a reranker is configured. Otherwise they are the cosine similarity
    between the query embedding and each document embedding. Either way a
    higher score means more relevant. Each returned document carries its
    score in ``metadata["score"]``.
    """

    embeddings_function: Any  # callable: str or list[str] -> embedding(s)
    reranking_function: Any  # optional; must provide .predict(pairs)
    r_score: float  # minimum score to keep; a falsy value disables the filter
    top_n: int  # maximum number of documents returned

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Return at most ``top_n`` documents ordered from most to least relevant."""
        if not documents:
            # Nothing to score; avoid calling the models with empty input.
            return []

        if self.reranking_function:
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            query_embedding = self.embeddings_function(query)
            document_embedding = self.embeddings_function(
                [doc.page_content for doc in documents]
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        docs_with_scores = list(zip(documents, scores.tolist()))
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        # Both score sources are similarities (higher = more relevant), which
        # is also what the r_score filter above assumes, so always rank
        # descending. Sorting the cosine path ascending would keep the *least*
        # relevant documents.
        ranked = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        # Build fresh metadata dicts so the caller's documents are not mutated
        # (and documents without metadata do not crash).
        return [
            Document(
                page_content=doc.page_content,
                metadata={**(doc.metadata or {}), "score": doc_score},
            )
            for doc, doc_score in ranked[: self.top_n]
        ]