main.py 42.5 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
Que Nguyen's avatar
Que Nguyen committed
11
import requests
12
import os, shutil, logging, re
mindspawn's avatar
mindspawn committed
13
from datetime import datetime
14
15

from pathlib import Path
16
from typing import List, Union, Sequence, Iterator, Any
Timothy J. Baek's avatar
Timothy J. Baek committed
17

18
from chromadb.utils.batch_utils import create_batches
19
from langchain_core.documents import Document
Timothy J. Baek's avatar
Timothy J. Baek committed
20

Timothy J. Baek's avatar
Timothy J. Baek committed
21
22
23
24
25
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
26
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
28
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
29
30
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
31
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
32
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
33
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
34
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
35
    YoutubeLoader,
mindspawn's avatar
mindspawn committed
36
    OutlookMessageLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
37
)
38
39
from langchain.text_splitter import RecursiveCharacterTextSplitter

40
41
42
43
44
import validators
import urllib.parse
import socket


45
46
from pydantic import BaseModel
from typing import Optional
47
import mimetypes
48
import uuid
49
50
import json

51
import sentence_transformers
52

Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
53
from apps.webui.models.documents import (
54
55
56
57
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
58

59
from apps.rag.utils import (
60
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
61
62
63
64
65
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
66
)
Timothy J. Baek's avatar
Timothy J. Baek committed
67

Timothy J. Baek's avatar
Timothy J. Baek committed
68
69
70
71
72
73
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
74
from apps.rag.search.serply import search_serply
75
from apps.rag.search.duckduckgo import search_duckduckgo
Timothy J. Baek's avatar
Timothy J. Baek committed
76

77
78
79
80
81
82
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
83
from utils.utils import get_current_user, get_admin_user
84

85
from config import (
86
    AppConfig,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
87
    ENV,
88
    SRC_LOG_LEVELS,
89
90
    UPLOAD_DIR,
    DOCS_DIR,
91
92
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
93
    RAG_EMBEDDING_ENGINE,
94
    RAG_EMBEDDING_MODEL,
95
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
96
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
97
    ENABLE_RAG_HYBRID_SEARCH,
98
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
99
    RAG_RERANKING_MODEL,
100
    PDF_EXTRACT_IMAGES,
101
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
102
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
103
104
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
105
    DEVICE_TYPE,
106
107
108
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
109
    RAG_TEMPLATE,
110
    ENABLE_RAG_LOCAL_WEB_FETCH,
111
    YOUTUBE_LOADER_LANGUAGE,
Timothy J. Baek's avatar
Timothy J. Baek committed
112
    ENABLE_RAG_WEB_SEARCH,
Timothy J. Baek's avatar
Timothy J. Baek committed
113
    RAG_WEB_SEARCH_ENGINE,
114
    RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
Timothy J. Baek's avatar
Timothy J. Baek committed
115
116
117
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
Timothy J. Baek's avatar
Timothy J. Baek committed
118
    BRAVE_SEARCH_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
119
120
121
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
122
    SERPLY_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
123
    RAG_WEB_SEARCH_RESULT_COUNT,
124
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
125
    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
126
)
127

128
129
from constants import ERROR_MESSAGES

130
131
132
# Module-level logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

Timothy J. Baek's avatar
Timothy J. Baek committed
133
134
# FastAPI application object for this module. Settings that the update
# endpoints below can change at runtime are kept on ``app.state.config`` and
# are seeded here from the constants imported from ``config``.
app = FastAPI()
app.state.config = AppConfig()

# Retrieval parameters.
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION

# Text chunking.
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking models and the prompt template.
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# Remote embedding API endpoint and credentials.
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# Document loaders.
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
# NOTE: unlike the other settings, this one is stored directly on app.state,
# not on app.state.config.
app.state.YOUTUBE_LOADER_TRANSLATION = None

# Web search.
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = RAG_WEB_SEARCH_WHITE_LIST_DOMAINS

# Per-provider web search credentials and options.
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS


181
182
183
184
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load or clear the SentenceTransformer kept on ``app.state``.

    A model is loaded only when ``embedding_model`` is non-empty and the
    configured ``RAG_EMBEDDING_ENGINE`` is the empty string; in every other
    case ``app.state.sentence_transformer_ef`` is set to ``None``.

    Args:
        embedding_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    if not embedding_model or app.state.config.RAG_EMBEDDING_ENGINE != "":
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load or clear the CrossEncoder kept on ``app.state``.

    With an empty ``reranking_model`` the stored reranker
    (``app.state.sentence_transformer_rf``) is reset to ``None``; otherwise a
    ``CrossEncoder`` is built from the path returned by ``get_model_path``.

    Args:
        reranking_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Import-time initialisation: load the configured embedding and reranking
# models, then build the embedding callable used by the query endpoints.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Must run after update_embedding_model(), which sets
# app.state.sentence_transformer_ef.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
229
230
# CORS: every origin, method and header is accepted, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended for the deployment.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
241
# Base request body for endpoints that write into a named collection.
class CollectionNameForm(BaseModel):
    # Target collection; defaults to "test" when the client omits the field.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
245
# Request body for the URL-ingesting endpoints (/youtube and /web); adds the
# source URL to the inherited collection_name.
class UrlForm(CollectionNameForm):
    # Address of the resource to load.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
248

249
250
251
252
# Request body carrying a search query plus the inherited collection_name.
class SearchForm(CollectionNameForm):
    query: str


Timothy J. Baek's avatar
Timothy J. Baek committed
253
254
@app.get("/")
async def get_status():
    """Return a status flag plus the current chunking, template and model settings."""
    cfg = app.state.config
    return {
        "status": True,
        "chunk_size": cfg.CHUNK_SIZE,
        "chunk_overlap": cfg.CHUNK_OVERLAP,
        "template": cfg.RAG_TEMPLATE,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "reranking_model": cfg.RAG_RERANKING_MODEL,
        "openai_batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
267
268
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine, model and remote API settings (admin only)."""
    cfg = app.state.config
    openai_config = {
        "url": cfg.OPENAI_API_BASE_URL,
        "key": cfg.OPENAI_API_KEY,
        "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }
    return {
        "status": True,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
281
282
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the configured reranking model (admin only)."""
    # The misspelled function name is kept on purpose: it is the handler's
    # public identifier.
    current_model = app.state.config.RAG_RERANKING_MODEL
    return {
        "status": True,
        "reranking_model": current_model,
    }
Steven Kreitzer's avatar
Steven Kreitzer committed
287
288


289
290
291
# Remote embedding API settings submitted with an embedding update.
class OpenAIConfigForm(BaseModel):
    # Base URL of the API.
    url: str
    # API key.
    key: str
    # Optional batch size; update_embedding_config stores 1 when this is unset.
    batch_size: Optional[int] = None
293
294


295
# Request body for POST /embedding/update.
class EmbeddingModelUpdateForm(BaseModel):
    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    # Engine identifier stored as RAG_EMBEDDING_ENGINE.
    embedding_engine: str
    # Model identifier stored as RAG_EMBEDDING_MODEL.
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
301
302
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function.

    When the engine is "ollama" or "openai" and ``openai_config`` is supplied,
    the API URL, key and batch size are stored as well (the batch size falls
    back to 1 when it is missing or 0).

    Returns:
        The resulting embedding settings.

    Raises:
        HTTPException: 500 when any step of the update fails.
    """
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        cfg = app.state.config
        cfg.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        cfg.RAG_EMBEDDING_MODEL = form_data.embedding_model

        remote = form_data.openai_config
        if cfg.RAG_EMBEDDING_ENGINE in ["ollama", "openai"] and remote is not None:
            cfg.OPENAI_API_BASE_URL = remote.url
            cfg.OPENAI_API_KEY = remote.key
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE = remote.batch_size or 1

        # Reload (or clear) the local model first; the embedding function
        # built next receives app.state.sentence_transformer_ef.
        update_embedding_model(cfg.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            cfg.RAG_EMBEDDING_ENGINE,
            cfg.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            cfg.OPENAI_API_KEY,
            cfg.OPENAI_API_BASE_URL,
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        openai_config = {
            "url": cfg.OPENAI_API_BASE_URL,
            "key": cfg.OPENAI_API_KEY,
            "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        }
        return {
            "status": True,
            "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
            "embedding_model": cfg.RAG_EMBEDDING_MODEL,
            "openai_config": openai_config,
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
349
350


Steven Kreitzer's avatar
Steven Kreitzer committed
351
352
# Request body for POST /reranking/update.
class RerankingModelUpdateForm(BaseModel):
    # Model identifier stored as RAG_RERANKING_MODEL.
    reranking_model: str
353

Steven Kreitzer's avatar
Steven Kreitzer committed
354
355
356
357
358
359

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Store a new reranking model name and reload the reranker (admin only).

    Returns:
        A dict with the status flag and the reranking model now configured.

    Raises:
        HTTPException: 500 when the model cannot be updated.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # Fix: the original read ``update_reranking_model(model), True``. That
        # builds and discards a tuple, so ``True`` never reached the function.
        # It is now passed as ``update_model``.
        update_reranking_model(
            app.state.config.RAG_RERANKING_MODEL,
            update_model=True,
        )

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
379
380
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF, chunking, YouTube and web settings (admin only)."""
    cfg = app.state.config

    search_settings = {
        "enabled": cfg.ENABLE_RAG_WEB_SEARCH,
        "engine": cfg.RAG_WEB_SEARCH_ENGINE,
        "searxng_query_url": cfg.SEARXNG_QUERY_URL,
        "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
        "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
        "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
        "serpstack_api_key": cfg.SERPSTACK_API_KEY,
        "serpstack_https": cfg.SERPSTACK_HTTPS,
        "serper_api_key": cfg.SERPER_API_KEY,
        "serply_api_key": cfg.SERPLY_API_KEY,
        "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    }
    chunk_settings = {
        "chunk_size": cfg.CHUNK_SIZE,
        "chunk_overlap": cfg.CHUNK_OVERLAP,
    }
    youtube_settings = {
        "language": cfg.YOUTUBE_LOADER_LANGUAGE,
        # Stored on app.state directly, not on app.state.config.
        "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
    }
    web_settings = {
        "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "search": search_settings,
    }

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
        "youtube": youtube_settings,
        "web": web_settings,
    }


# "chunk" section of ConfigUpdateForm; both values are required when the
# section is present and are stored as CHUNK_SIZE / CHUNK_OVERLAP.
class ChunkParamUpdateForm(BaseModel):
    chunk_size: int
    chunk_overlap: int


417
418
419
420
421
# "youtube" section of ConfigUpdateForm; both values are later passed to
# YoutubeLoader.from_youtube_url as its language / translation arguments.
class YoutubeLoaderConfig(BaseModel):
    language: List[str]
    translation: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
422
# "web.search" section of ConfigUpdateForm. Only `enabled` is required; every
# other field defaults to None.
class WebSearchConfig(BaseModel):
    # Master switch, stored as ENABLE_RAG_WEB_SEARCH.
    enabled: bool
    # Search engine selector, stored as RAG_WEB_SEARCH_ENGINE.
    engine: Optional[str] = None

    # Per-provider settings.
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    serply_api_key: Optional[str] = None

    # Stored as RAG_WEB_SEARCH_RESULT_COUNT / RAG_WEB_SEARCH_CONCURRENT_REQUESTS.
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
437
438
439
440
441
# "web" section of ConfigUpdateForm: required search settings plus an
# optional SSL-verification flag for the web loader.
class WebConfig(BaseModel):
    search: WebSearchConfig
    web_loader_ssl_verification: Optional[bool] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
442
# Request body for POST /config/update. Every section is optional;
# update_rag_config skips any section that is None.
class ConfigUpdateForm(BaseModel):
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
447
448
449
450


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply a partial RAG configuration update (admin only).

    Sections of ``form_data`` that are ``None`` are skipped. Inside the
    ``web`` section, ``web_loader_ssl_verification``, ``search.result_count``
    and ``search.concurrent_requests`` keep their current value when omitted,
    matching how ``pdf_extract_images`` is handled. Previously an omitted
    field overwrote the live setting with ``None``. The remaining search
    fields (engine, URLs, API keys) are still stored exactly as received,
    including ``None``.

    Returns:
        The full configuration after the update, in the same shape as
        ``GET /config``.
    """
    cfg = app.state.config

    if form_data.pdf_extract_images is not None:
        cfg.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images

    if form_data.chunk is not None:
        cfg.CHUNK_SIZE = form_data.chunk.chunk_size
        cfg.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        cfg.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        # Stored on app.state directly, not on app.state.config.
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        web = form_data.web
        search = web.search

        # Do not silently replace the SSL-verification flag with None when
        # the client leaves it out.
        if web.web_loader_ssl_verification is not None:
            cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
                web.web_loader_ssl_verification
            )

        cfg.ENABLE_RAG_WEB_SEARCH = search.enabled
        cfg.RAG_WEB_SEARCH_ENGINE = search.engine
        cfg.SEARXNG_QUERY_URL = search.searxng_query_url
        cfg.GOOGLE_PSE_API_KEY = search.google_pse_api_key
        cfg.GOOGLE_PSE_ENGINE_ID = search.google_pse_engine_id
        cfg.BRAVE_SEARCH_API_KEY = search.brave_search_api_key
        cfg.SERPSTACK_API_KEY = search.serpstack_api_key
        cfg.SERPSTACK_HTTPS = search.serpstack_https
        cfg.SERPER_API_KEY = search.serper_api_key
        cfg.SERPLY_API_KEY = search.serply_api_key

        # Numeric limits: keep the current value when the field is omitted.
        if search.result_count is not None:
            cfg.RAG_WEB_SEARCH_RESULT_COUNT = search.result_count
        if search.concurrent_requests is not None:
            cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = search.concurrent_requests

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": cfg.ENABLE_RAG_WEB_SEARCH,
                "engine": cfg.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": cfg.SEARXNG_QUERY_URL,
                "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": cfg.SERPSTACK_API_KEY,
                "serpstack_https": cfg.SERPSTACK_HTTPS,
                "serper_api_key": cfg.SERPER_API_KEY,
                "serply_api_key": cfg.SERPLY_API_KEY,
                "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
518
519


Timothy J. Baek's avatar
Timothy J. Baek committed
520
521
522
523
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the current RAG prompt template (any authenticated user)."""
    current_template = app.state.config.RAG_TEMPLATE
    return {
        "status": True,
        "template": current_template,
    }


528
529
530
531
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the template, top-k, relevance threshold and hybrid flag (admin only)."""
    cfg = app.state.config
    return {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
537
538


539
540
# Request body for POST /query/settings/update. A missing or falsy field makes
# update_query_settings fall back to its built-in default for that setting.
class QuerySettingsForm(BaseModel):
    # Number of results to retrieve (stored as TOP_K).
    k: Optional[int] = None
    # Relevance threshold (stored as RELEVANCE_THRESHOLD).
    r: Optional[float] = None
    # Prompt template (stored as RAG_TEMPLATE).
    template: Optional[str] = None
    # Hybrid search switch (stored as ENABLE_RAG_HYBRID_SEARCH).
    hybrid: Optional[bool] = None
544
545
546
547
548
549


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the query settings (admin only).

    A missing or falsy field is not left unchanged; it is reset to a default:
    the imported ``RAG_TEMPLATE`` for the template, 4 for ``k``, 0.0 for ``r``
    and ``False`` for ``hybrid``.

    Returns:
        The query settings after the update.
    """
    cfg = app.state.config

    cfg.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    cfg.TOP_K = form_data.k or 4
    cfg.RELEVANCE_THRESHOLD = form_data.r or 0.0
    cfg.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
565
566


567
# Request body for POST /query/doc.
class QueryDocForm(BaseModel):
    # Collection to search.
    collection_name: str
    # Query text.
    query: str
    # Optional per-request overrides of the configured TOP_K and
    # RELEVANCE_THRESHOLD.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE: query_doc_handler does not read this field; it uses the
    # server-wide ENABLE_RAG_HYBRID_SEARCH flag instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
573
574


575
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query one collection, using hybrid search when it is enabled server-wide.

    ``form_data.k`` and ``form_data.r`` override the configured ``TOP_K`` and
    ``RELEVANCE_THRESHOLD`` when they are truthy. ``r`` and the reranker are
    only used on the hybrid path.

    Raises:
        HTTPException: 400 when the query raises any exception.
    """
    try:
        cfg = app.state.config

        if not cfg.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k or cfg.TOP_K,
            )

        return query_doc_with_hybrid_search(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=form_data.k or cfg.TOP_K,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or cfg.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
605
606


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
607
608
609
# Request body for POST /query/collection.
class QueryCollectionsForm(BaseModel):
    # Collections to search.
    collection_names: List[str]
    # Query text.
    query: str
    # Optional per-request overrides of the configured TOP_K and
    # RELEVANCE_THRESHOLD.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE: query_collection_handler does not read this field; it uses the
    # server-wide ENABLE_RAG_HYBRID_SEARCH flag instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
613
614


615
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections, using hybrid search when it is enabled server-wide.

    ``form_data.k`` and ``form_data.r`` override the configured ``TOP_K`` and
    ``RELEVANCE_THRESHOLD`` when they are truthy. ``r`` and the reranker are
    only used on the hybrid path.

    Raises:
        HTTPException: 400 when the query raises any exception.
    """
    try:
        cfg = app.state.config

        if not cfg.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k or cfg.TOP_K,
            )

        return query_collection_with_hybrid_search(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=form_data.k or cfg.TOP_K,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or cfg.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
646
647


Timothy J. Baek's avatar
Timothy J. Baek committed
648
649
650
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video through ``YoutubeLoader`` and store it in the vector DB.

    When ``collection_name`` is the empty string, the first 63 characters of
    the URL's SHA-256 string are used as the collection name. The target
    collection is written with ``overwrite=True``.

    Raises:
        HTTPException: 400 when loading or storing raises any exception.
    """
    try:
        youtube_loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        documents = youtube_loader.load()

        target_collection = (
            calculate_sha256_string(form_data.url)[:63]
            if form_data.collection_name == ""
            else form_data.collection_name
        )

        store_data_in_vector_db(documents, target_collection, overwrite=True)
        return {
            "status": True,
            "collection_name": target_collection,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


677
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page with the web loader and index its content.

    Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    The documents are stored (with overwrite) under
    ``form_data.collection_name``; when that is an empty string, the first
    63 characters of ``calculate_sha256_string(form_data.url)`` are used.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.DEFAULT(e)`` for any
            exception raised while loading or storing.
    """
    try:
        page_docs = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        ).load()

        if form_data.collection_name == "":
            target_collection = calculate_sha256_string(form_data.url)[:63]
        else:
            target_collection = form_data.collection_name

        store_data_in_vector_db(page_docs, target_collection, overwrite=True)

        response = {"status": True}
        response["collection_name"] = target_collection
        response["filename"] = form_data.url
        return response
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

704

705
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a ``SafeWebBaseLoader`` for one URL or a sequence of URLs.

    The URL(s) are checked with ``validate_url`` first.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` when validation returns a
            falsy value (``validate_url`` may also raise it itself).
    """
    if validate_url(url):
        loader_options = {
            "verify_ssl": verify_ssl,
            "requests_per_second": RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            "continue_on_failure": True,
        }
        return SafeWebBaseLoader(url, **loader_options)

    raise ValueError(ERROR_MESSAGES.INVALID_URL)
715
716


717
718
719
720
def validate_url(url: Union[str, Sequence[str]]):
    """Validate a single URL string or every URL in a sequence.

    Returns:
        True when the URL (or every URL of the sequence) passes; False when
        ``url`` is neither a string nor a sequence, or when some element of
        a sequence is such a value.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` when a string is not a
            well-formed URL, or when local web fetch is disabled and the
            host resolves to a private IPv4/IPv6 address.
    """
    if not isinstance(url, str):
        if isinstance(url, Sequence):
            return all(validate_url(candidate) for candidate in url)
        return False

    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    if ENABLE_RAG_LOCAL_WEB_FETCH:
        return True

    # Local web fetch is disabled: reject any URL whose host resolves to a
    # private IP address.
    # This is technically still vulnerable to DNS rebinding attacks, as we
    # don't control WebBaseLoader.
    host = urllib.parse.urlparse(url).hostname
    v4_list, v6_list = resolve_hostname(host)

    for address in v4_list:
        if validators.ipv4(address, private=True):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
    for address in v6_list:
        if validators.ipv6(address, private=True):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)

    return True

Timothy J. Baek's avatar
Timothy J. Baek committed
740

Timothy J. Baek's avatar
revert  
Timothy J. Baek committed
741
742
743
744
745
746
747
748
749
750
751
def resolve_hostname(hostname):
    """Resolve ``hostname`` and split the results by address family.

    Returns:
        A ``(ipv4_addresses, ipv6_addresses)`` tuple of lists, in the order
        ``socket.getaddrinfo`` reported them (duplicates are kept).
    """
    by_family = {socket.AF_INET: [], socket.AF_INET6: []}

    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        bucket = by_family.get(family)
        if bucket is not None:
            # The first element of the socket address is the IP string.
            bucket.append(sockaddr[0])

    return by_family[socket.AF_INET], by_family[socket.AF_INET6]


Timothy J. Baek's avatar
Timothy J. Baek committed
752
753
754
755
756
757
758
759
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Run a web search with the named engine and return its results.

    Supported ``engine`` values and the ``app.state.config`` settings each
    one requires:
    - "searxng": SEARXNG_QUERY_URL
    - "google_pse": GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - "brave": BRAVE_SEARCH_API_KEY
    - "serpstack": SERPSTACK_API_KEY
    - "serper": SERPER_API_KEY
    - "serply": SERPLY_API_KEY
    - "duckduckgo": no key required

    Args:
        engine (str): Name of the search engine to use
        query (str): The query to search for

    Raises:
        Exception: when the required setting for the engine is empty, or the
            engine name is not one of the above.
    """
    config = app.state.config

    # TODO: add playwright to search the web
    if engine == "searxng":
        if not config.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(
            config.SEARXNG_QUERY_URL,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
        )

    if engine == "google_pse":
        if not (config.GOOGLE_PSE_API_KEY and config.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            config.GOOGLE_PSE_API_KEY,
            config.GOOGLE_PSE_ENGINE_ID,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
        )

    if engine == "brave":
        if not config.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(
            config.BRAVE_SEARCH_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
        )

    if engine == "serpstack":
        if not config.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            config.SERPSTACK_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
            https_enabled=config.SERPSTACK_HTTPS,
        )

    if engine == "serper":
        if not config.SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(
            config.SERPER_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
        )

    if engine == "serply":
        if not config.SERPLY_API_KEY:
            raise Exception("No SERPLY_API_KEY found in environment variables")
        return search_serply(
            config.SERPLY_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
        )

    if engine == "duckduckgo":
        return search_duckduckgo(
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
        )

    raise Exception("No search engine API key found in environment variables")


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
840
841
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Run a web search, load the result pages and index them.

    The search uses the engine configured in
    ``app.state.config.RAG_WEB_SEARCH_ENGINE``. The loaded pages are stored
    (with overwrite) under ``form_data.collection_name``; when that is an
    empty string, the first 63 characters of
    ``calculate_sha256_string(form_data.query)`` are used.

    Returns:
        A dict with ``status``, ``collection_name`` and ``filenames`` (the
        list of result URLs).

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.WEB_SEARCH_ERROR(e)`` when
            the search fails, or 400 with ``ERROR_MESSAGES.DEFAULT(e)`` when
            loading or storing the pages fails.
    """
    try:
        # Use the module logger, like every other handler here, rather than
        # the root ``logging`` module.
        log.info(
            f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
        )
        web_results = search_web(
            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        # log.exception already records the message and traceback, so no
        # extra print is needed.
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        urls = [result.link for result in web_results]
        loader = get_web_loader(urls)
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


881
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in the vector DB.

    Args:
        data: Documents to split; passed to
            ``RecursiveCharacterTextSplitter.split_documents``.
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` when splitting yields
            no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    # Return the plain boolean. The previous ``(result, None)`` tuple was
    # always truthy, so callers testing ``if result:`` could never see a
    # failed store.
    return store_docs_in_vector_db(docs, collection_name, overwrite)
896
897
898


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk a raw text string and store the chunks in the vector DB.

    ``metadata`` is attached to the documents created from ``text``; the
    result of ``store_docs_in_vector_db`` is returned unchanged.
    """
    splitter_options = {
        "chunk_size": app.state.config.CHUNK_SIZE,
        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        "add_start_index": True,
    }
    chunks = RecursiveCharacterTextSplitter(**splitter_options).create_documents(
        [text], metadatas=[metadata]
    )
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
910
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a new Chroma collection.

    Args:
        docs: Documents exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection with the same name is
            deleted before the new one is created.

    Returns:
        True on success, and also when the failure is an exception whose
        class is named ``UniqueConstraintError``; False for any other
        exception (which is logged).
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [chunk.page_content for chunk in docs]
    metadatas = [chunk.metadata for chunk in docs]

    # ChromaDB does not like datetime formats
    # for meta-data so convert them to string.
    for meta in metadatas:
        datetime_keys = [k for k, v in meta.items() if isinstance(v, datetime)]
        for k in datetime_keys:
            meta[k] = str(meta[k])

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if existing.name == collection_name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embed = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        # Newlines are flattened to spaces for embedding only; the stored
        # documents keep the original text.
        embeddings = embed([text.replace("\n", " ") for text in texts])

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # An exception class named "UniqueConstraintError" is reported as
        # success; everything else is a failure.
        return type(e).__name__ == "UniqueConstraintError"


962
963
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a document loader for a file from its extension and MIME type.

    Args:
        filename: File name; the text after the last "." (lower-cased) is
            used as the extension.
        file_content_type: MIME type of the file (may be falsy).
        file_path: Path handed to the chosen loader.

    Returns:
        ``(loader, known_type)``. ``known_type`` is False only when nothing
        matched and a ``TextLoader`` is used as the fallback.
    """
    extension = filename.split(".")[-1].lower()

    known_source_ext = [
        "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
        "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", "r",
        "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", "nginxconf",
        "conf", "m", "mm", "plsql", "perl", "rb", "rs", "db2", "scala",
        "bash", "swift", "vue", "svelte", "msg",
    ]

    # Extension-only matches are checked first, ahead of any MIME-type rule.
    def _html_loader():
        return BSHTMLLoader(file_path, open_encoding="unicode_escape")

    extension_factories = {
        "pdf": lambda: PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        ),
        "csv": lambda: CSVLoader(file_path),
        "rst": lambda: UnstructuredRSTLoader(file_path, mode="elements"),
        "xml": lambda: UnstructuredXMLLoader(file_path),
        "htm": _html_loader,
        "html": _html_loader,
        "md": lambda: UnstructuredMarkdownLoader(file_path),
    }
    factory = extension_factories.get(extension)
    if factory is not None:
        return factory(), True

    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True

    docx_mime = (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    )
    if file_content_type == docx_mime or extension in ["doc", "docx"]:
        return Docx2txtLoader(file_path), True

    excel_mimes = [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ]
    if file_content_type in excel_mimes or extension in ["xls", "xlsx"]:
        return UnstructuredExcelLoader(file_path), True

    powerpoint_mimes = [
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ]
    if file_content_type in powerpoint_mimes or extension in ["ppt", "pptx"]:
        return UnstructuredPowerPointLoader(file_path), True

    if extension == "msg":
        return OutlookMessageLoader(file_path), True

    looks_like_text = bool(file_content_type) and "text/" in file_content_type
    if extension in known_source_ext or looks_like_text:
        return TextLoader(file_path, autodetect_encoding=True), True

    # Unknown type: still try to read it as text, but tell the caller.
    return TextLoader(file_path, autodetect_encoding=True), False


1057
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under ``UPLOAD_DIR`` and index its content.

    Args:
        collection_name: Target collection; when omitted, the first 63
            characters of ``calculate_sha256`` of the saved file are used.
        file: The uploaded file.
        user: Current user (authentication dependency).

    Returns:
        A dict with ``status``, ``collection_name``, ``filename`` and
        ``known_type`` when storing reports success.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.PANDOC_NOT_INSTALLED`` when
            the error text contains "No pandoc was found", otherwise 400
            with ``ERROR_MESSAGES.DEFAULT(e)``.
    """
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # Keep only the final path component of the client-supplied name.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            # The context manager closes the handle even if hashing raises.
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)

            if result:
                return {
                    "status": True,
                    "collection_name": collection_name,
                    "filename": filename,
                    "known_type": known_type,
                }
        except Exception as e:
            # NOTE: this HTTPException is raised inside the outer ``try``,
            # so the outer ``except Exception`` handler below receives it.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=e,
            )
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
1112
1113


1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
class TextRAGForm(BaseModel):
    """Request body for the POST /text endpoint (``store_text``)."""

    # Stored as the "name" entry of the document metadata.
    name: str
    # The raw text that gets chunked and indexed.
    content: str
    # Target collection; when None, store_text derives a name from a hash
    # of ``content``.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text snippet in the vector DB.

    The text is stored under ``form_data.collection_name``; when that is
    None, a name is derived from ``calculate_sha256_string`` of the content.
    The metadata records ``form_data.name`` and the id of the calling user.

    Returns:
        A dict with ``status`` and ``collection_name`` on success.

    Raises:
        HTTPException: 500 with ``ERROR_MESSAGES.DEFAULT()`` when storing
            reports failure.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncate to 63 characters, as every other endpoint in this module
        # does for hash-derived names (Chroma limits collection-name length).
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


1145
1146
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under ``DOCS_DIR`` and register new ones.

    Each file is loaded, stored in the vector DB under the first 63
    characters of ``calculate_sha256`` of the file, and, if no document
    with the sanitized file name exists yet, inserted into ``Documents``
    with the tags returned by ``extract_folders_after_data_docs``.

    Errors for an individual file are logged and do not stop the scan.

    Returns:
        Always True.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            # Skip directories and dot-files.
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # The context manager closes the handle even if hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            try:
                result = store_data_in_vector_db(data, collection_name)

                if result:
                    sanitized_filename = sanitize_filename(filename)
                    doc = Documents.get_doc_by_name(sanitized_filename)

                    if doc is None:
                        content = (
                            json.dumps({"tags": [{"name": name} for name in tags]})
                            if len(tags)
                            else "{}"
                        )
                        Documents.insert_new_doc(
                            user.id,
                            DocumentForm(
                                **{
                                    "name": sanitized_filename,
                                    "title": filename,
                                    "collection_name": collection_name,
                                    "filename": filename,
                                    "content": content,
                                }
                            ),
                        )
            except Exception as e:
                # Best effort: a file that cannot be stored or registered is
                # logged and the scan moves on.
                log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1206
@app.get("/reset/db")
1207
1208
def reset_vector_db(user=Depends(get_admin_user)):
    """Call ``CHROMA_CLIENT.reset()``; access is gated by ``get_admin_user``."""
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
1209
1210


Timothy J. Baek's avatar
Timothy J. Baek committed
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    """Delete every file, link and sub-directory inside ``UPLOAD_DIR``.

    Failures are logged and never raised: an entry that cannot be removed
    is skipped, and a missing directory is only reported.

    Returns:
        Always True.
    """
    folder = f"{UPLOAD_DIR}"
    try:
        # Check if the directory exists
        if os.path.exists(folder):
            # Iterate over all the files and directories in the specified directory
            for filename in os.listdir(folder):
                file_path = os.path.join(folder, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)  # Remove the file or link
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)  # Remove the directory
                except Exception as e:
                    # Report through the module logger, like the /reset
                    # endpoint does, instead of printing to stdout.
                    log.error(f"Failed to delete {file_path}. Reason: {e}")
        else:
            log.warning(f"The directory {folder} does not exist")
    except Exception as e:
        log.error(f"Failed to process the directory {folder}. Reason: {e}")

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1235
@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Empty ``UPLOAD_DIR`` and reset the Chroma client.

    Both steps are best effort: failures are logged and never raised, and
    the vector DB reset is attempted even when the upload directory cannot
    be listed.

    Returns:
        Always True.
    """
    folder = f"{UPLOAD_DIR}"

    try:
        entries = os.listdir(folder)
    except OSError as e:
        # A missing or unreadable upload directory must not prevent the
        # vector DB reset below.
        log.error(f"Failed to process the directory {folder}. Reason: {e}")
        entries = []

    for filename in entries:
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s" % (file_path, e))

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1254

Timothy J. Baek's avatar
Timothy J. Baek committed
1255

1256
1257
class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader that logs and skips URLs that fail to load."""

    def lazy_load(self) -> Iterator[Document]:
        """Yield one Document per URL in ``web_paths``.

        A URL that raises is logged and skipped, and iteration continues
        with the next one.
        """
        for url in self.web_paths:
            try:
                page = self._scrape(url, bs_kwargs=self.bs_kwargs)
                body_text = page.get_text(**self.bs_get_text_kwargs)

                # Collect whatever page-level metadata is present.
                info = {"source": url}

                title_tag = page.find("title")
                if title_tag:
                    info["title"] = title_tag.get_text()

                description_tag = page.find("meta", attrs={"name": "description"})
                if description_tag:
                    info["description"] = description_tag.get(
                        "content", "No description found."
                    )

                html_tag = page.find("html")
                if html_tag:
                    info["language"] = html_tag.get("lang", "No language found.")

                yield Document(page_content=body_text, metadata=info)
            except Exception as e:
                # Log the error and continue with the next URL
                log.error(f"Error loading {url}: {e}")
Timothy J. Baek's avatar
Timothy J. Baek committed
1281
1282


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1283
1284
1285
1286
1287
1288
1289
1290
1291
# Debug endpoints: only registered when ENV is "dev".
if ENV == "dev":

    @app.get("/ef")
    async def get_embeddings():
        """Return the embedding function's output for the text "hello world"."""
        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        """Return the embedding function's output for the text given in the path."""
        return {"result": app.state.EMBEDDING_FUNCTION(text)}