main.py 42.6 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
Que Nguyen's avatar
Que Nguyen committed
11
import requests
12
import os, shutil, logging, re
mindspawn's avatar
mindspawn committed
13
from datetime import datetime
14
15

from pathlib import Path
16
from typing import List, Union, Sequence, Iterator, Any
Timothy J. Baek's avatar
Timothy J. Baek committed
17

18
from chromadb.utils.batch_utils import create_batches
19
from langchain_core.documents import Document
Timothy J. Baek's avatar
Timothy J. Baek committed
20

Timothy J. Baek's avatar
Timothy J. Baek committed
21
22
23
24
25
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
26
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
28
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
29
30
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
31
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
32
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
33
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
34
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
35
    YoutubeLoader,
mindspawn's avatar
mindspawn committed
36
    OutlookMessageLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
37
)
38
39
from langchain.text_splitter import RecursiveCharacterTextSplitter

40
41
42
43
44
import validators
import urllib.parse
import socket


45
46
from pydantic import BaseModel
from typing import Optional
47
import mimetypes
48
import uuid
49
50
import json

51
import sentence_transformers
52

Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
53
from apps.webui.models.documents import (
54
55
56
57
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
58

59
from apps.rag.utils import (
60
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
61
62
63
64
65
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
66
)
Timothy J. Baek's avatar
Timothy J. Baek committed
67

Timothy J. Baek's avatar
Timothy J. Baek committed
68
69
70
71
72
73
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
74
from apps.rag.search.serply import search_serply
75
from apps.rag.search.duckduckgo import search_duckduckgo
76
from apps.rag.search.tavily import search_tavily
Timothy J. Baek's avatar
Timothy J. Baek committed
77

78
79
80
81
82
83
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
84
from utils.utils import get_current_user, get_admin_user
85

86
from config import (
87
    AppConfig,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
88
    ENV,
89
    SRC_LOG_LEVELS,
90
91
    UPLOAD_DIR,
    DOCS_DIR,
92
93
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
94
    RAG_EMBEDDING_ENGINE,
95
    RAG_EMBEDDING_MODEL,
96
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
97
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
98
    ENABLE_RAG_HYBRID_SEARCH,
99
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
100
    RAG_RERANKING_MODEL,
101
    PDF_EXTRACT_IMAGES,
102
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
103
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
104
105
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
106
    DEVICE_TYPE,
107
108
109
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
110
    RAG_TEMPLATE,
111
    ENABLE_RAG_LOCAL_WEB_FETCH,
112
    YOUTUBE_LOADER_LANGUAGE,
Timothy J. Baek's avatar
Timothy J. Baek committed
113
    ENABLE_RAG_WEB_SEARCH,
Timothy J. Baek's avatar
Timothy J. Baek committed
114
    RAG_WEB_SEARCH_ENGINE,
Timothy J. Baek's avatar
Timothy J. Baek committed
115
116
117
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
Timothy J. Baek's avatar
Timothy J. Baek committed
118
    BRAVE_SEARCH_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
119
120
121
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
122
    SERPLY_API_KEY,
123
    TAVILY_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
124
    RAG_WEB_SEARCH_RESULT_COUNT,
125
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
126
    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
127
)
128

129
130
from constants import ERROR_MESSAGES

131
132
133
# Module-level logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

Timothy J. Baek's avatar
Timothy J. Baek committed
134
135
# FastAPI application object that every route in this module is registered on.
app = FastAPI()

# Mutable runtime settings live on app.state.config; each value below is
# seeded from the matching constant imported from `config`.
app.state.config = AppConfig()

# Retrieval defaults.
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD

app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Text-splitting parameters.
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking models and the prompt template.
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE


app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES


app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
# NOTE: the translation setting is stored directly on app.state, not on
# app.state.config, and starts out unset.
app.state.YOUTUBE_LOADER_TRANSLATION = None


# Web-search settings and the per-provider credentials.
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE

app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS


182
183
184
185
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load (or clear) the local SentenceTransformer used for embeddings.

    The model is only loaded when ``embedding_model`` is non-empty AND the
    configured embedding engine is the empty string; in every other case
    ``app.state.sentence_transformer_ef`` is reset to ``None``.

    Args:
        embedding_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded as the second argument of ``get_model_path``.
    """
    if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
            get_model_path(embedding_model, update_model),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
        )
    else:
        app.state.sentence_transformer_ef = None


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load (or clear) the CrossEncoder stored on ``app.state.sentence_transformer_rf``.

    Args:
        reranking_model: Model identifier; a falsy value clears the reranker.
        update_model: Forwarded as the second argument of ``get_model_path``.
    """
    # Nothing configured: drop any previously loaded reranker and stop.
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    resolved_path = get_model_path(reranking_model, update_model)
    cross_encoder = sentence_transformers.CrossEncoder(
        resolved_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )
    app.state.sentence_transformer_rf = cross_encoder


# Load the configured models once, at import time.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)

update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)


# Build the embedding callable from the current engine/model settings; it is
# rebuilt by the /embedding/update endpoint whenever those settings change.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)

# NOTE(review): every origin is allowed while credentials are enabled;
# confirm this is intended for the deployment.
origins = ["*"]


app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
242
class CollectionNameForm(BaseModel):
    """Base request body carrying an optional target collection name."""

    # Defaults to the literal "test" when the client omits the field.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
246
class UrlForm(CollectionNameForm):
    """Request body for endpoints that ingest a URL into a collection."""

    # Required URL of the resource to load.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
249

250
251
252
253
class SearchForm(CollectionNameForm):
    """Request body carrying a query string plus the inherited collection name."""

    query: str


Timothy J. Baek's avatar
Timothy J. Baek committed
254
255
@app.get("/")
async def get_status():
    """Return a snapshot of the current chunking, template and model settings.

    This route declares no user dependency.
    """
    return {
        "status": True,
        "chunk_size": app.state.config.CHUNK_SIZE,
        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        "template": app.state.config.RAG_TEMPLATE,
        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
268
269
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and the OpenAI-compatible settings.

    Gated by the ``get_admin_user`` dependency. Note that the response
    includes the stored API key verbatim.
    """
    return {
        "status": True,
        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
        "openai_config": {
            "url": app.state.config.OPENAI_API_BASE_URL,
            "key": app.state.config.OPENAI_API_KEY,
            "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        },
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
282
283
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the currently configured reranking model.

    Gated by the ``get_admin_user`` dependency. The misspelled function name
    is kept as-is so any existing reference to it keeps working.
    """
    return {
        "status": True,
        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
    }
Steven Kreitzer's avatar
Steven Kreitzer committed
288
289


290
291
292
class OpenAIConfigForm(BaseModel):
    """Connection settings for an OpenAI-compatible embedding endpoint."""

    url: str
    key: str
    # Optional; the update endpoint substitutes 1 when this is missing or 0.
    batch_size: Optional[int] = None
294
295


296
class EmbeddingModelUpdateForm(BaseModel):
    """Request body for POST /embedding/update."""

    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
302
303
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function.

    Args:
        form_data: New engine, model and (optionally) OpenAI-compatible
            connection settings.
        user: Resolved by the ``get_admin_user`` dependency.

    Returns:
        The resulting embedding configuration.

    Raises:
        HTTPException: 500 when any step of the update fails. Settings
            assigned before the failure are not rolled back.
    """
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        # Remote engines take their connection settings from the form.
        if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
            if form_data.openai_config is not None:
                app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.config.OPENAI_API_KEY = form_data.openai_config.key
                # A missing (or zero) batch size falls back to 1.
                app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
                    form_data.openai_config.batch_size
                    if form_data.openai_config.batch_size
                    else 1
                )

        # Reload (or clear) the local SentenceTransformer for the new settings.
        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        return {
            "status": True,
            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.config.OPENAI_API_BASE_URL,
                "key": app.state.config.OPENAI_API_KEY,
                "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
350
351


Steven Kreitzer's avatar
Steven Kreitzer committed
352
353
class RerankingModelUpdateForm(BaseModel):
    """Request body for POST /reranking/update."""

    reranking_model: str
354

Steven Kreitzer's avatar
Steven Kreitzer committed
355
356
357
358
359
360

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload the CrossEncoder.

    Args:
        form_data: Carries the new reranking model identifier.
        user: Resolved by the ``get_admin_user`` dependency.

    Returns:
        The resulting reranking configuration.

    Raises:
        HTTPException: 500 when loading the model fails. The configured
            model name is not rolled back in that case.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # update_model=True must be an argument of the call. It used to sit
        # outside the parentheses, forming a discarded tuple, so the flag was
        # never passed.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
380
381
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF, chunking, YouTube and web/search settings.

    Gated by the ``get_admin_user`` dependency. The response contains the
    stored search-provider API keys verbatim.
    """
    return {
        "status": True,
        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.config.CHUNK_SIZE,
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
            # Stored on app.state directly rather than app.state.config.
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
                "serper_api_key": app.state.config.SERPER_API_KEY,
                "serply_api_key": app.state.config.SERPLY_API_KEY,
                "tavily_api_key": app.state.config.TAVILY_API_KEY,
                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }


class ChunkParamUpdateForm(BaseModel):
    """Chunking parameters accepted by POST /config/update; both are required."""

    chunk_size: int
    chunk_overlap: int


419
420
421
422
423
class YoutubeLoaderConfig(BaseModel):
    """YouTube loader settings accepted by POST /config/update."""

    language: List[str]
    translation: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
424
class WebSearchConfig(BaseModel):
    """Web-search settings accepted by POST /config/update.

    Only ``enabled`` is required; every other field defaults to ``None``.
    """

    enabled: bool
    engine: Optional[str] = None
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    serply_api_key: Optional[str] = None
    tavily_api_key: Optional[str] = None
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
440
441
442
443
444
class WebConfig(BaseModel):
    """Web loader and web-search settings accepted by POST /config/update."""

    search: WebSearchConfig
    web_loader_ssl_verification: Optional[bool] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
445
class ConfigUpdateForm(BaseModel):
    """Request body for POST /config/update; every section is optional."""

    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
450
451
452
453


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply the supplied sections of the RAG configuration.

    Sections left out of the request (``None``) are not touched. Within the
    ``web`` section, the search fields are copied as sent, so an omitted
    search field overwrites the stored value with ``None``.

    Args:
        form_data: Sections to update.
        user: Resolved by the ``get_admin_user`` dependency.

    Returns:
        The full configuration after the update, in the same shape as
        GET /config.
    """
    app.state.config.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else app.state.config.PDF_EXTRACT_IMAGES
    )

    if form_data.chunk is not None:
        app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
        app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        # Only overwrite the SSL flag when the client actually sent it. An
        # omitted field (None) used to replace the stored boolean with a
        # falsy value.
        if form_data.web.web_loader_ssl_verification is not None:
            app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
                form_data.web.web_loader_ssl_verification
            )

        app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
        app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
        app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
        app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
        app.state.config.GOOGLE_PSE_ENGINE_ID = (
            form_data.web.search.google_pse_engine_id
        )
        app.state.config.BRAVE_SEARCH_API_KEY = (
            form_data.web.search.brave_search_api_key
        )
        app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
        app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
        app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
        app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
        app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
        app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
        app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
            form_data.web.search.concurrent_requests
        )

    return {
        "status": True,
        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.config.CHUNK_SIZE,
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
                "serper_api_key": app.state.config.SERPER_API_KEY,
                "serply_api_key": app.state.config.SERPLY_API_KEY,
                "tavily_api_key": app.state.config.TAVILY_API_KEY,
                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
523
524


Timothy J. Baek's avatar
Timothy J. Baek committed
525
526
527
528
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the current RAG prompt template.

    Uses the ``get_current_user`` dependency rather than ``get_admin_user``.
    """
    return {
        "status": True,
        "template": app.state.config.RAG_TEMPLATE,
    }


533
534
535
536
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the template, top-k, relevance threshold and hybrid-search flag.

    Gated by the ``get_admin_user`` dependency.
    """
    return {
        "status": True,
        "template": app.state.config.RAG_TEMPLATE,
        "k": app.state.config.TOP_K,
        "r": app.state.config.RELEVANCE_THRESHOLD,
        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
542
543


544
545
class QuerySettingsForm(BaseModel):
    """Request body for POST /query/settings/update; every field is optional."""

    k: Optional[int] = None
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None
549
550
551
552
553
554


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the query settings with the submitted values.

    Any field that is missing or falsy is reset to a fixed fallback rather
    than left unchanged: the imported ``RAG_TEMPLATE`` for the template, 4
    for ``k``, 0.0 for ``r`` and ``False`` for ``hybrid``.

    Args:
        form_data: New query settings.
        user: Resolved by the ``get_admin_user`` dependency.

    Returns:
        The query settings after the update.
    """
    app.state.config.RAG_TEMPLATE = (
        form_data.template if form_data.template else RAG_TEMPLATE
    )
    app.state.config.TOP_K = form_data.k if form_data.k else 4
    app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
    app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
        form_data.hybrid if form_data.hybrid else False
    )
    return {
        "status": True,
        "template": app.state.config.RAG_TEMPLATE,
        "k": app.state.config.TOP_K,
        "r": app.state.config.RELEVANCE_THRESHOLD,
        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
    }
570
571


572
class QueryDocForm(BaseModel):
    """Request body for POST /query/doc."""

    collection_name: str
    query: str
    # Optional per-request overrides of the configured top-k / threshold.
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
578
579


580
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single collection, using hybrid search when it is enabled.

    The hybrid/plain choice follows ``app.state.config.ENABLE_RAG_HYBRID_SEARCH``;
    ``form_data.hybrid`` is not consulted here.

    Args:
        form_data: Collection name, query text and optional ``k``/``r`` overrides.
        user: Resolved by the ``get_current_user`` dependency.

    Returns:
        Whatever the underlying query helper returns.

    Raises:
        HTTPException: 400 when the query helper raises.
    """
    try:
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.config.TOP_K,
                reranking_function=app.state.sentence_transformer_rf,
                # Compare against None so an explicit threshold of 0.0 is
                # honoured instead of being replaced by the configured value.
                r=(
                    form_data.r
                    if form_data.r is not None
                    else app.state.config.RELEVANCE_THRESHOLD
                ),
            )
        else:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.config.TOP_K,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
610
611


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
612
613
614
class QueryCollectionsForm(BaseModel):
    """Request body for POST /query/collection."""

    collection_names: List[str]
    query: str
    # Optional per-request overrides of the configured top-k / threshold.
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
618
619


620
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections, using hybrid search when it is enabled.

    The hybrid/plain choice follows ``app.state.config.ENABLE_RAG_HYBRID_SEARCH``;
    ``form_data.hybrid` is not consulted here.

    Args:
        form_data: Collection names, query text and optional ``k``/``r`` overrides.
        user: Resolved by the ``get_current_user`` dependency.

    Returns:
        Whatever the underlying query helper returns.

    Raises:
        HTTPException: 400 when the query helper raises.
    """
    try:
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.config.TOP_K,
                reranking_function=app.state.sentence_transformer_rf,
                # Compare against None so an explicit threshold of 0.0 is
                # honoured instead of being replaced by the configured value.
                r=(
                    form_data.r
                    if form_data.r is not None
                    else app.state.config.RELEVANCE_THRESHOLD
                ),
            )
        else:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.config.TOP_K,
            )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
651
652


Timothy J. Baek's avatar
Timothy J. Baek committed
653
654
655
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube URL through ``YoutubeLoader`` and store it in the vector DB.

    Args:
        form_data: URL to load and the target collection name.
        user: Resolved by the ``get_current_user`` dependency.

    Returns:
        A dict with the status, the collection name used, and the URL under
        the ``filename`` key.

    Raises:
        HTTPException: 400 when loading or storing fails.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # An empty name is replaced by the URL's SHA-256 string cut to 63 chars.
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


682
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and index its content.

    The page is loaded with the loader returned by ``get_web_loader``, using
    the configured SSL-verification setting. When the form's
    ``collection_name`` is the empty string, the name is derived from the
    SHA-256 of the URL (first 63 characters). Any existing collection with
    that name is overwritten.

    Raises:
        HTTPException: 400 wrapping any error raised while loading or storing.
    """
    # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        page_docs = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        ).load()

        target_collection = form_data.collection_name
        if target_collection == "":
            target_collection = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(page_docs, target_collection, overwrite=True)

        response = {
            "status": True,
            "collection_name": target_collection,
            "filename": form_data.url,
        }
        return response
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

709

710
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a ``SafeWebBaseLoader`` for one URL or a sequence of URLs.

    Args:
        url: A single URL or a sequence of URLs to load.
        verify_ssl: Forwarded to the loader unchanged.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` when ``validate_url``
            rejects the input (it may also raise the same error itself).
    """
    if validate_url(url):
        loader_options = {
            "verify_ssl": verify_ssl,
            "requests_per_second": RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            "continue_on_failure": True,
        }
        return SafeWebBaseLoader(url, **loader_options)

    # validate_url returned a falsy value without raising.
    raise ValueError(ERROR_MESSAGES.INVALID_URL)
720
721


722
723
724
725
def validate_url(url: Union[str, Sequence[str]]):
    """Validate a URL, or every URL in a sequence.

    A string must pass ``validators.url``. When ``ENABLE_RAG_LOCAL_WEB_FETCH``
    is falsy, the hostname is additionally resolved and the URL is rejected if
    any resolved IPv4/IPv6 address is flagged as private by ``validators``.

    Returns:
        ``True`` for an accepted string; for a non-string sequence, whether
        every element validates; ``False`` for any other input type.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` for a malformed URL or one
            resolving to a private address.
    """
    if not isinstance(url, str):
        if isinstance(url, Sequence):
            return all(validate_url(candidate) for candidate in url)
        return False

    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    if ENABLE_RAG_LOCAL_WEB_FETCH:
        return True

    # Local web fetch is disabled: reject URLs that resolve to private addresses.
    # This is technically still vulnerable to DNS rebinding attacks, as we
    # don't control WebBaseLoader.
    hostname = urllib.parse.urlparse(url).hostname
    ipv4_addresses, ipv6_addresses = resolve_hostname(hostname)

    if any(validators.ipv4(address, private=True) for address in ipv4_addresses):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    if any(validators.ipv6(address, private=True) for address in ipv6_addresses):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    return True

Timothy J. Baek's avatar
Timothy J. Baek committed
745

Timothy J. Baek's avatar
revert  
Timothy J. Baek committed
746
747
748
749
750
751
752
753
754
755
756
def resolve_hostname(hostname):
    """Resolve ``hostname`` and split the results by address family.

    Returns:
        A ``(ipv4_addresses, ipv6_addresses)`` tuple of lists, each holding
        the address strings reported by ``socket.getaddrinfo`` in order.
    """
    ipv4_addresses = []
    ipv6_addresses = []

    # Each entry is (family, type, proto, canonname, sockaddr); the address
    # string is the first element of sockaddr.
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            ipv4_addresses.append(sockaddr[0])
        elif family == socket.AF_INET6:
            ipv6_addresses.append(sockaddr[0])

    return ipv4_addresses, ipv6_addresses


Timothy J. Baek's avatar
Timothy J. Baek committed
757
758
759
760
761
762
763
764
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Run ``query`` against the search backend named by ``engine``.

    Recognised engines: ``searxng``, ``google_pse``, ``brave``, ``serpstack``,
    ``serper``, ``serply``, ``duckduckgo`` and ``tavily``. Except for
    ``duckduckgo``, each engine needs its URL / API key present in the app
    config; the result count comes from ``RAG_WEB_SEARCH_RESULT_COUNT``.

    Args:
        engine (str): Name of the search backend to use
        query (str): The query to search for

    Raises:
        Exception: When the engine's credentials are missing from the config
            or the engine name is not recognised.
    """

    # TODO: add playwright to search the web
    config = app.state.config

    if engine == "searxng":
        if not config.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(
            config.SEARXNG_QUERY_URL,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "google_pse":
        if not (config.GOOGLE_PSE_API_KEY and config.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            config.GOOGLE_PSE_API_KEY,
            config.GOOGLE_PSE_ENGINE_ID,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "brave":
        if not config.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(
            config.BRAVE_SEARCH_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "serpstack":
        if not config.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            config.SERPSTACK_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            https_enabled=config.SERPSTACK_HTTPS,
        )

    if engine == "serper":
        if not config.SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(
            config.SERPER_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "serply":
        if not config.SERPLY_API_KEY:
            raise Exception("No SERPLY_API_KEY found in environment variables")
        return search_serply(
            config.SERPLY_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "duckduckgo":
        return search_duckduckgo(query, config.RAG_WEB_SEARCH_RESULT_COUNT)

    if engine == "tavily":
        if not config.TAVILY_API_KEY:
            raise Exception("No TAVILY_API_KEY found in environment variables")
        return search_tavily(
            config.TAVILY_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    raise Exception("No search engine API key found in environment variables")


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
848
849
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Run a web search, load the result pages and index their content.

    The search uses the engine configured in ``RAG_WEB_SEARCH_ENGINE``. When
    the form's ``collection_name`` is the empty string, the name is derived
    from the SHA-256 of the query (first 63 characters). Any existing
    collection with that name is overwritten.

    Raises:
        HTTPException: 400 with ``WEB_SEARCH_ERROR`` when the search itself
            fails, or 400 with the default error message when loading or
            storing the result pages fails.
    """
    try:
        # Use the module logger (not the root ``logging`` functions) so this
        # message follows the same configuration as the rest of the file.
        log.info(
            f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
        )
        web_results = search_web(
            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        urls = [result.link for result in web_results]
        # NOTE(review): unlike store_web, the SSL-verification setting is not
        # forwarded to the loader here — confirm whether that is intentional.
        loader = get_web_loader(urls)
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


889
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: Documents to split, as accepted by
            ``RecursiveCharacterTextSplitter.split_documents``.
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``. A plain bool is
        returned (not a tuple) so that callers' truthiness checks actually
        reflect whether storing succeeded.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` when splitting yields no
            chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    return store_docs_in_vector_db(docs, collection_name, overwrite)
904
905
906


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw text into chunks and store them in a collection.

    Args:
        text: The text to split.
        metadata: Metadata dict attached to the text when creating documents.
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The result of ``store_docs_in_vector_db``.
    """
    config = app.state.config
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=config.CHUNK_SIZE,
        chunk_overlap=config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
918
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a newly created collection.

    Args:
        docs: Documents exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection with the same name is
            deleted before the new one is created.

    Returns:
        ``True`` on success, or when the raised exception's class is named
        ``UniqueConstraintError``; ``False`` for any other exception.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = []
    metadatas = []
    for doc in docs:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)

    # ChromaDB does not like datetime formats
    # for meta-data so convert them to string (in place).
    for meta in metadatas:
        stringified = {
            key: str(value)
            for key, value in meta.items()
            if isinstance(value, datetime)
        }
        meta.update(stringified)

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if existing.name != collection_name:
                    continue
                log.info(f"deleting existing collection {collection_name}")
                CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embed = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        # Newlines are flattened only for embedding; the stored documents keep
        # their original text.
        flattened_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embed(flattened_texts)

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        return e.__class__.__name__ == "UniqueConstraintError"


970
971
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a document loader for a file from its extension and content type.

    Args:
        filename: File name; the text after the last ``.`` (lower-cased) is
            used as the extension.
        file_content_type: MIME type of the file; may be falsy.
        file_path: Path handed to the chosen loader.

    Returns:
        A ``(loader, known_type)`` tuple. ``known_type`` is ``False`` only
        when nothing matched and ``TextLoader`` is used as a last resort.
    """
    known_source_ext = {
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
        "msg",
    }
    docx_content_type = (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    )
    excel_content_types = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
    powerpoint_content_types = (
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )

    file_ext = filename.rsplit(".", 1)[-1].lower()

    # The checks below are order-sensitive: the first match wins.
    if file_ext == "pdf":
        pdf_loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
        return pdf_loader, True
    if file_ext == "csv":
        return CSVLoader(file_path), True
    if file_ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if file_ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if file_ext in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if file_ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    if file_content_type == docx_content_type or file_ext in ("doc", "docx"):
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_content_types or file_ext in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True
    if file_content_type in powerpoint_content_types or file_ext in ("ppt", "pptx"):
        return UnstructuredPowerPointLoader(file_path), True
    if file_ext == "msg":
        return OutlookMessageLoader(file_path), True

    # Everything else is read as plain text; the flag records whether the
    # extension or MIME type actually indicated text.
    looks_textual = file_ext in known_source_ext or bool(
        file_content_type and "text/" in file_content_type
    )
    return TextLoader(file_path, autodetect_encoding=True), looks_textual


1065
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file to ``UPLOAD_DIR`` and index its content.

    When ``collection_name`` is not supplied, it is derived from the SHA-256
    of the saved file (first 63 characters).

    Returns:
        A dict with ``status``, ``collection_name``, ``filename`` and
        ``known_type`` (whether ``get_loader`` recognised the file type).

    Raises:
        HTTPException: 400 when saving, loading or storing raises (with a
            dedicated message when pandoc is missing); 500 when storing
            reports failure through a falsy return value.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # Drop any directory components supplied by the client.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        # Errors raised while storing are handled by the single handler below.
        # The former nested handler raised an HTTPException(500) that this
        # outer handler immediately re-caught and re-wrapped as a 400, so the
        # 500 never reached the client.
        result = store_data_in_vector_db(data, collection_name)
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )

    if not result:
        # Previously a falsy result fell through and the endpoint answered
        # 200 with a null body; report the failure like the /text endpoint.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "filename": filename,
        "known_type": known_type,
    }
1120
1121


1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
class TextRAGForm(BaseModel):
    """Request body for the ``/text`` endpoint."""

    # Stored under the "name" key of the chunk metadata by store_text.
    name: str
    # Raw text that store_text splits and stores.
    content: str
    # Target collection; when None, store_text derives a name from a hash of
    # the content.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text snippet submitted by the user.

    When the form carries no ``collection_name``, it is derived from the
    SHA-256 of the content, truncated to 63 characters like the other
    endpoints in this file do for their hash-derived names.

    Returns:
        A dict with ``status`` and ``collection_name`` on success.

    Raises:
        HTTPException: 500 when storing reports failure.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # NOTE(review): the 63-character cut mirrors the sibling endpoints and
        # presumably reflects the vector store's collection-name length limit.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


1153
1154
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file found under ``DOCS_DIR``.

    Each file is hashed to obtain its collection name, loaded with the loader
    chosen by ``get_loader`` and stored. When storing succeeds and no
    document with the sanitized file name exists yet, a document record is
    inserted, tagged with the values returned by
    ``extract_folders_after_data_docs``.

    Failures are logged per file and never abort the scan.

    Returns:
        Always ``True``.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # Context manager guarantees the handle is closed even if hashing
            # raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, _ = get_loader(filename, file_content_type[0], str(path))
            data = loader.load()

            try:
                result = store_data_in_vector_db(data, collection_name)

                if result:
                    sanitized_filename = sanitize_filename(filename)
                    doc = Documents.get_doc_by_name(sanitized_filename)

                    if doc is None:
                        content = (
                            json.dumps({"tags": [{"name": name} for name in tags]})
                            if len(tags)
                            else "{}"
                        )
                        Documents.insert_new_doc(
                            user.id,
                            DocumentForm(
                                **{
                                    "name": sanitized_filename,
                                    "title": filename,
                                    "collection_name": collection_name,
                                    "filename": filename,
                                    "content": content,
                                }
                            ),
                        )
            except Exception as e:
                # Best effort: a file that cannot be stored is logged and
                # skipped.
                log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1214
@app.get("/reset/db")
1215
1216
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the vector store by calling ``CHROMA_CLIENT.reset()``.

    The route depends on ``get_admin_user``. Nothing is returned explicitly.
    """
    # NOTE(review): presumably this drops every collection — confirm against
    # the Chroma client documentation before exposing it more widely.
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
1217
1218


Timothy J. Baek's avatar
Timothy J. Baek committed
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    """Delete everything inside ``UPLOAD_DIR``, keeping the directory itself.

    Deletion is best effort: a failure on one entry is logged and the
    remaining entries are still processed. Problems are reported through the
    module logger rather than ``print`` so they follow the logging
    configuration used by the rest of the file.

    Returns:
        Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"
    try:
        # Check if the directory exists
        if os.path.exists(folder):
            # Iterate over all the files and directories in the specified directory
            for filename in os.listdir(folder):
                file_path = os.path.join(folder, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)  # Remove the file or link
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)  # Remove the directory
                except Exception as e:
                    log.error(f"Failed to delete {file_path}. Reason: {e}")
        else:
            log.warning(f"The directory {folder} does not exist")
    except Exception as e:
        log.error(f"Failed to process the directory {folder}. Reason: {e}")

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1243
@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything inside ``UPLOAD_DIR`` and reset the vector store.

    Both steps are best effort: failures are logged and do not prevent the
    remaining work. In particular, an uploads directory that is missing or
    unreadable no longer aborts the request before the vector store is reset.

    Returns:
        Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"

    try:
        entries = os.listdir(folder)
    except OSError as e:
        # Previously this raised out of the handler and skipped the reset
        # below.
        log.error("Failed to list %s. Reason: %s" % (folder, e))
        entries = []

    for filename in entries:
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s" % (file_path, e))

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1262

Timothy J. Baek's avatar
Timothy J. Baek committed
1263

1264
1265
class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader that logs and skips URLs that fail to load."""

    def lazy_load(self) -> Iterator[Document]:
        """Yield one ``Document`` per successfully scraped URL in ``web_paths``.

        Each document's metadata holds the ``source`` URL plus, when the page
        provides them, its ``title``, ``description`` and ``language``. An
        error on one URL is logged and iteration continues with the next.
        """
        for path in self.web_paths:
            try:
                soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
                page_text = soup.get_text(**self.bs_get_text_kwargs)

                # Collect whatever page-level metadata is available.
                page_metadata = {"source": path}

                title_tag = soup.find("title")
                if title_tag:
                    page_metadata["title"] = title_tag.get_text()

                description_tag = soup.find("meta", attrs={"name": "description"})
                if description_tag:
                    page_metadata["description"] = description_tag.get(
                        "content", "No description found."
                    )

                html_tag = soup.find("html")
                if html_tag:
                    page_metadata["language"] = html_tag.get(
                        "lang", "No language found."
                    )

                yield Document(page_content=page_text, metadata=page_metadata)
            except Exception as e:
                # Log the error and continue with the next URL
                log.error(f"Error loading {path}: {e}")
Timothy J. Baek's avatar
Timothy J. Baek committed
1289
1290


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1291
1292
1293
1294
1295
1296
1297
1298
1299
# Debug routes registered only when ENV is "dev". They declare no
# authentication dependency.
if ENV == "dev":

    @app.get("/ef")
    async def get_embeddings():
        """Return the embedding function's output for the fixed text "hello world"."""
        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        """Return the embedding function's output for the ``text`` path parameter."""
        return {"result": app.state.EMBEDDING_FUNCTION(text)}