main.py 44.7 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
Que Nguyen's avatar
Que Nguyen committed
11
import requests
12
import os, shutil, logging, re
mindspawn's avatar
mindspawn committed
13
from datetime import datetime
14
15

from pathlib import Path
16
from typing import List, Union, Sequence, Iterator, Any
Timothy J. Baek's avatar
Timothy J. Baek committed
17

18
from chromadb.utils.batch_utils import create_batches
19
from langchain_core.documents import Document
Timothy J. Baek's avatar
Timothy J. Baek committed
20

Timothy J. Baek's avatar
Timothy J. Baek committed
21
22
23
24
25
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
26
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
28
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
29
30
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
31
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
32
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
33
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
34
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
35
    YoutubeLoader,
mindspawn's avatar
mindspawn committed
36
    OutlookMessageLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
37
)
38
39
from langchain.text_splitter import RecursiveCharacterTextSplitter

40
41
42
43
44
import validators
import urllib.parse
import socket


45
46
from pydantic import BaseModel
from typing import Optional
47
import mimetypes
48
import uuid
49
50
import json

51
import sentence_transformers
52

Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
53
from apps.webui.models.documents import (
54
55
56
57
    Documents,
    DocumentForm,
    DocumentResponse,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
58
59
60
from apps.webui.models.files import (
    Files,
)
Jannik Streidl's avatar
Jannik Streidl committed
61

62
from apps.rag.utils import (
63
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
64
65
66
67
68
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
69
)
Timothy J. Baek's avatar
Timothy J. Baek committed
70

Timothy J. Baek's avatar
Timothy J. Baek committed
71
72
73
74
75
76
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
77
from apps.rag.search.serply import search_serply
78
from apps.rag.search.duckduckgo import search_duckduckgo
79
from apps.rag.search.tavily import search_tavily
Timothy J. Baek's avatar
Timothy J. Baek committed
80

81
82
83
84
85
86
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
87
from utils.utils import get_current_user, get_admin_user
88

89
from config import (
90
    AppConfig,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
91
    ENV,
92
    SRC_LOG_LEVELS,
93
94
    UPLOAD_DIR,
    DOCS_DIR,
95
96
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
97
    RAG_EMBEDDING_ENGINE,
98
    RAG_EMBEDDING_MODEL,
99
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
100
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
101
    ENABLE_RAG_HYBRID_SEARCH,
102
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
103
    RAG_RERANKING_MODEL,
104
    PDF_EXTRACT_IMAGES,
105
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
106
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
107
108
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
109
    DEVICE_TYPE,
110
111
112
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
113
    RAG_TEMPLATE,
114
    ENABLE_RAG_LOCAL_WEB_FETCH,
115
    YOUTUBE_LOADER_LANGUAGE,
Timothy J. Baek's avatar
Timothy J. Baek committed
116
    ENABLE_RAG_WEB_SEARCH,
Timothy J. Baek's avatar
Timothy J. Baek committed
117
    RAG_WEB_SEARCH_ENGINE,
Que Nguyen's avatar
Que Nguyen committed
118
    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
Timothy J. Baek's avatar
Timothy J. Baek committed
119
120
121
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
Timothy J. Baek's avatar
Timothy J. Baek committed
122
    BRAVE_SEARCH_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
123
124
125
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
126
    SERPLY_API_KEY,
127
    TAVILY_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
128
    RAG_WEB_SEARCH_RESULT_COUNT,
129
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
130
    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
131
)
132

133
134
from constants import ERROR_MESSAGES

135
136
137
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

# The RAG sub-application. Runtime-tunable settings live on
# ``app.state.config`` (an ``AppConfig``); they are seeded here from the values
# imported from ``config`` and are read/overwritten by the admin endpoints
# defined further down in this module.
app = FastAPI()

app.state.config = AppConfig()

# Retrieval defaults used when a query request does not supply k / r.
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD

app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Document chunking parameters.
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking model selection and the RAG prompt template.
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# Endpoint/credentials used when the embedding engine is "ollama" or "openai".
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
# NOTE: the YouTube translation target is stored on ``app.state`` directly,
# not on ``app.state.config``, and starts as None on every startup.
app.state.YOUTUBE_LOADER_TRANSLATION = None

# Web search: on/off switch, provider selection and domain filter.
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST

# Per-provider web search endpoints and credentials.
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY

# Web search result volume and fetch parallelism.
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS


187
188
189
190
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """(Re)load the local sentence-transformers embedding model.

    A local model is only loaded when a model name is given AND the embedding
    engine is the empty string (i.e. no remote engine is configured). In every
    other case the cached model on ``app.state.sentence_transformer_ef`` is
    cleared.

    Args:
        embedding_model: Model identifier, resolved through ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` (presumably allows the
            model files to be refreshed — confirm in apps.rag.utils).
    """
    wants_local_model = (
        bool(embedding_model) and app.state.config.RAG_EMBEDDING_ENGINE == ""
    )
    if not wants_local_model:
        app.state.sentence_transformer_ef = None
        return

    resolved_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        resolved_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """(Re)load the sentence-transformers CrossEncoder used for reranking.

    A falsy ``reranking_model`` (e.g. ``""``) unloads the reranker by setting
    ``app.state.sentence_transformer_rf`` to None.

    Args:
        reranking_model: Model identifier, resolved through ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` (presumably allows the
            model files to be refreshed — confirm in apps.rag.utils).
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    resolved_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        resolved_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the local embedding / reranking models (when configured) at import time.
update_embedding_model(
    embedding_model=app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    reranking_model=app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Embedding callable built from the current engine settings; rebuilt by the
# /embedding/update endpoint whenever the engine or model changes.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)

origins = ["*"]

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive (Starlette may echo the request Origin for credentialed requests);
# confirm this is intended for this sub-app.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
247
# Base request body for endpoints that write into a vector collection.
class CollectionNameForm(BaseModel):
    # Target collection; defaults to "test" when the client omits it.
    collection_name: Union[str, None] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
251
# Request body carrying a single URL to ingest (plus the inherited collection name).
class UrlForm(CollectionNameForm):
    url: str  # page or video URL to load

Timothy J. Baek's avatar
Timothy J. Baek committed
254

255
256
257
258
# Request body carrying a free-text search query (plus the inherited collection name).
class SearchForm(CollectionNameForm):
    query: str  # search terms


Timothy J. Baek's avatar
Timothy J. Baek committed
259
260
@app.get("/")
async def get_status():
    """Report that the RAG service is up, together with its model/chunk settings.

    No user dependency is declared on this route, so the payload deliberately
    contains no credentials.
    """
    cfg = app.state.config
    payload = {"status": True}
    payload["chunk_size"] = cfg.CHUNK_SIZE
    payload["chunk_overlap"] = cfg.CHUNK_OVERLAP
    payload["template"] = cfg.RAG_TEMPLATE
    payload["embedding_engine"] = cfg.RAG_EMBEDDING_ENGINE
    payload["embedding_model"] = cfg.RAG_EMBEDDING_MODEL
    payload["reranking_model"] = cfg.RAG_RERANKING_MODEL
    payload["openai_batch_size"] = cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE
    return payload


Timothy J. Baek's avatar
Timothy J. Baek committed
273
274
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and OpenAI-compatible settings.

    Guarded by ``get_admin_user``. NOTE: the response contains the stored API
    key in clear text.
    """
    cfg = app.state.config
    engine_name = cfg.RAG_EMBEDDING_ENGINE
    model_name = cfg.RAG_EMBEDDING_MODEL
    openai_section = {
        "url": cfg.OPENAI_API_BASE_URL,
        "key": cfg.OPENAI_API_KEY,
        "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }
    return {
        "status": True,
        "embedding_engine": engine_name,
        "embedding_model": model_name,
        "openai_config": openai_section,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
287
288
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the configured reranking model name (guarded by ``get_admin_user``).

    The misspelt function name is kept because it is part of the module's
    public surface.
    """
    current_model = app.state.config.RAG_RERANKING_MODEL
    return dict(status=True, reranking_model=current_model)
Steven Kreitzer's avatar
Steven Kreitzer committed
293
294


295
296
297
# Connection settings for an OpenAI-compatible embeddings endpoint.
class OpenAIConfigForm(BaseModel):
    url: str  # API base URL
    key: str  # API key
    # Texts per embedding request; a missing/zero value is stored as 1 by
    # the /embedding/update handler.
    batch_size: Union[int, None] = None
299
300


301
# Request body for POST /embedding/update.
class EmbeddingModelUpdateForm(BaseModel):
    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Union[OpenAIConfigForm, None] = None
    embedding_engine: str  # "" selects the local sentence-transformers model
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
307
308
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild ``EMBEDDING_FUNCTION``.

    For the "ollama" and "openai" engines, a supplied ``openai_config`` also
    replaces the stored base URL, API key and batch size. A missing or zero
    batch size is stored as 1.

    Returns the resulting embedding configuration (same shape as
    ``GET /embedding``).

    Raises:
        HTTPException: 500 wrapping any failure while applying the update.

    NOTE(review): settings are written before the model is reloaded, so a
    failed load leaves the new settings in place next to a stale
    EMBEDDING_FUNCTION.
    """
    cfg = app.state.config
    log.info(
        f"Updating embedding model: {cfg.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        cfg.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        cfg.RAG_EMBEDDING_MODEL = form_data.embedding_model

        uses_remote_engine = cfg.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
        remote_settings = form_data.openai_config
        if uses_remote_engine and remote_settings is not None:
            cfg.OPENAI_API_BASE_URL = remote_settings.url
            cfg.OPENAI_API_KEY = remote_settings.key
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE = remote_settings.batch_size or 1

        # Reload (or clear) the local model, then rebuild the callable from it.
        update_embedding_model(cfg.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            cfg.RAG_EMBEDDING_ENGINE,
            cfg.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            cfg.OPENAI_API_KEY,
            cfg.OPENAI_API_BASE_URL,
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        openai_section = {
            "url": cfg.OPENAI_API_BASE_URL,
            "key": cfg.OPENAI_API_KEY,
            "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        }
        return {
            "status": True,
            "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
            "embedding_model": cfg.RAG_EMBEDDING_MODEL,
            "openai_config": openai_section,
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
355
356


Steven Kreitzer's avatar
Steven Kreitzer committed
357
358
# Request body for POST /reranking/update; "" unloads the reranker.
class RerankingModelUpdateForm(BaseModel):
    reranking_model: str  # model identifier passed to update_reranking_model
359

Steven Kreitzer's avatar
Steven Kreitzer committed
360
361
362
363
364
365

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it.

    A falsy model name unloads the reranker. The model is loaded with
    ``update_model=True``, which is forwarded to ``get_model_path``
    (presumably so a newly selected model can be fetched/refreshed).

    Returns the resulting reranking configuration.

    Raises:
        HTTPException: 500 wrapping any failure while loading the model.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # The flag must be passed inside the call: the previous form
        # `update_reranking_model(model), True` built and discarded a tuple,
        # so update_model silently stayed False.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, update_model=True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
385
386
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return ingestion, YouTube and web-search settings (guarded by ``get_admin_user``).

    NOTE: the ``web.search`` section includes provider API keys in clear text.
    """
    cfg = app.state.config

    chunk_section = {
        "chunk_size": cfg.CHUNK_SIZE,
        "chunk_overlap": cfg.CHUNK_OVERLAP,
    }
    youtube_section = {
        "language": cfg.YOUTUBE_LOADER_LANGUAGE,
        # Stored on app.state rather than app.state.config.
        "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
    }
    search_section = {
        "enabled": cfg.ENABLE_RAG_WEB_SEARCH,
        "engine": cfg.RAG_WEB_SEARCH_ENGINE,
        "searxng_query_url": cfg.SEARXNG_QUERY_URL,
        "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
        "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
        "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
        "serpstack_api_key": cfg.SERPSTACK_API_KEY,
        "serpstack_https": cfg.SERPSTACK_HTTPS,
        "serper_api_key": cfg.SERPER_API_KEY,
        "serply_api_key": cfg.SERPLY_API_KEY,
        "tavily_api_key": cfg.TAVILY_API_KEY,
        "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    }
    web_section = {
        "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "search": search_section,
    }

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": chunk_section,
        "youtube": youtube_section,
        "web": web_section,
    }


# Chunking parameters accepted by POST /config/update.
class ChunkParamUpdateForm(BaseModel):
    chunk_size: int  # target chunk length for the text splitter
    chunk_overlap: int  # overlap between consecutive chunks


424
425
426
427
428
# YouTube loader options accepted by POST /config/update.
class YoutubeLoaderConfig(BaseModel):
    language: List[str]  # passed as `language` to YoutubeLoader.from_youtube_url
    translation: Union[str, None] = None  # passed as `translation`; None = no translation


Timothy J. Baek's avatar
Timothy J. Baek committed
429
# Web-search settings accepted by POST /config/update. Every field except
# `enabled` is optional; the handler writes each one straight into app config.
class WebSearchConfig(BaseModel):
    enabled: bool
    engine: Union[str, None] = None  # provider name
    # Provider-specific endpoints / credentials.
    searxng_query_url: Union[str, None] = None
    google_pse_api_key: Union[str, None] = None
    google_pse_engine_id: Union[str, None] = None
    brave_search_api_key: Union[str, None] = None
    serpstack_api_key: Union[str, None] = None
    serpstack_https: Union[bool, None] = None
    serper_api_key: Union[str, None] = None
    serply_api_key: Union[str, None] = None
    tavily_api_key: Union[str, None] = None
    # Result volume and fetch parallelism.
    result_count: Union[int, None] = None
    concurrent_requests: Union[int, None] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
445
446
447
448
449
# The `web` section of POST /config/update.
class WebConfig(BaseModel):
    search: WebSearchConfig
    # Whether the web page loader verifies TLS certificates.
    web_loader_ssl_verification: Union[bool, None] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
450
# Request body for POST /config/update; any omitted section is left unchanged.
class ConfigUpdateForm(BaseModel):
    pdf_extract_images: Union[bool, None] = None
    chunk: Union[ChunkParamUpdateForm, None] = None
    youtube: Union[YoutubeLoaderConfig, None] = None
    web: Union[WebConfig, None] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
455
456
457
458


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply an update to the ingestion / YouTube / web-search settings.

    Each top-level section (``pdf_extract_images``, ``chunk``, ``youtube``,
    ``web``) is optional, and an omitted section leaves the current values
    untouched. Inside a supplied ``web`` section every search field is
    overwritten with what the client sent. The exception is
    ``web_loader_ssl_verification``, which only changes when it is explicitly
    provided.

    Returns the full resulting configuration in the same shape as
    ``GET /config``.
    """
    cfg = app.state.config

    cfg.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else cfg.PDF_EXTRACT_IMAGES
    )

    if form_data.chunk is not None:
        cfg.CHUNK_SIZE = form_data.chunk.chunk_size
        cfg.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        cfg.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        web = form_data.web
        # web_loader_ssl_verification defaults to None. Writing that None would
        # replace the boolean setting with a falsy value, which could disable
        # TLS verification for the web loader just because the client omitted
        # the field. Mirror the pdf_extract_images handling and keep the
        # current value.
        if web.web_loader_ssl_verification is not None:
            cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
                web.web_loader_ssl_verification
            )

        search = web.search
        cfg.ENABLE_RAG_WEB_SEARCH = search.enabled
        cfg.RAG_WEB_SEARCH_ENGINE = search.engine
        cfg.SEARXNG_QUERY_URL = search.searxng_query_url
        cfg.GOOGLE_PSE_API_KEY = search.google_pse_api_key
        cfg.GOOGLE_PSE_ENGINE_ID = search.google_pse_engine_id
        cfg.BRAVE_SEARCH_API_KEY = search.brave_search_api_key
        cfg.SERPSTACK_API_KEY = search.serpstack_api_key
        cfg.SERPSTACK_HTTPS = search.serpstack_https
        cfg.SERPER_API_KEY = search.serper_api_key
        cfg.SERPLY_API_KEY = search.serply_api_key
        cfg.TAVILY_API_KEY = search.tavily_api_key
        cfg.RAG_WEB_SEARCH_RESULT_COUNT = search.result_count
        cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = search.concurrent_requests

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": cfg.ENABLE_RAG_WEB_SEARCH,
                "engine": cfg.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": cfg.SEARXNG_QUERY_URL,
                "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": cfg.SERPSTACK_API_KEY,
                "serpstack_https": cfg.SERPSTACK_HTTPS,
                "serper_api_key": cfg.SERPER_API_KEY,
                "serply_api_key": cfg.SERPLY_API_KEY,
                "tavily_api_key": cfg.TAVILY_API_KEY,
                "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
528
529


Timothy J. Baek's avatar
Timothy J. Baek committed
530
531
532
533
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template to any user accepted by ``get_current_user``."""
    current_template = app.state.config.RAG_TEMPLATE
    return dict(status=True, template=current_template)


538
539
540
541
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the retrieval settings: template, top-k, relevance threshold, hybrid flag."""
    cfg = app.state.config
    settings = {"status": True}
    settings["template"] = cfg.RAG_TEMPLATE
    settings["k"] = cfg.TOP_K
    settings["r"] = cfg.RELEVANCE_THRESHOLD
    settings["hybrid"] = cfg.ENABLE_RAG_HYBRID_SEARCH
    return settings
Timothy J. Baek's avatar
Timothy J. Baek committed
547
548


549
550
# Request body for POST /query/settings/update; falsy/omitted fields reset to defaults.
class QuerySettingsForm(BaseModel):
    k: Union[int, None] = None  # top-k results
    r: Union[float, None] = None  # relevance threshold
    template: Union[str, None] = None  # RAG prompt template
    hybrid: Union[bool, None] = None  # enable hybrid search
554
555
556
557
558
559


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the retrieval settings and return the resulting values.

    Any falsy or omitted field resets that setting to its default (the
    imported RAG_TEMPLATE, k=4, r=0.0, hybrid=False) instead of keeping the
    current value.
    """
    cfg = app.state.config
    cfg.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    cfg.TOP_K = form_data.k or 4
    cfg.RELEVANCE_THRESHOLD = form_data.r or 0.0
    cfg.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return dict(
        status=True,
        template=cfg.RAG_TEMPLATE,
        k=cfg.TOP_K,
        r=cfg.RELEVANCE_THRESHOLD,
        hybrid=cfg.ENABLE_RAG_HYBRID_SEARCH,
    )
575
576


577
# Request body for POST /query/doc.
class QueryDocForm(BaseModel):
    collection_name: str
    query: str
    k: Union[int, None] = None  # overrides the configured TOP_K
    r: Union[float, None] = None  # overrides the configured relevance threshold
    # NOTE(review): not read by query_doc_handler; the global hybrid flag decides.
    hybrid: Union[bool, None] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
583
584


585
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single vector collection.

    When ENABLE_RAG_HYBRID_SEARCH is on, this runs hybrid search with the
    reranker; otherwise it runs plain embedding search. ``k`` falls back to
    the configured TOP_K when missing or 0. ``r`` falls back to
    RELEVANCE_THRESHOLD only when missing, so an explicit ``r=0.0`` is honoured
    instead of being replaced by the configured threshold.

    NOTE(review): ``form_data.hybrid`` is not consulted; the global setting decides.

    Raises:
        HTTPException: 400 wrapping any error raised by the query.
    """
    try:
        k = form_data.k if form_data.k else app.state.config.TOP_K
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            # `is not None` rather than truthiness: 0.0 is a valid threshold.
            r = (
                form_data.r
                if form_data.r is not None
                else app.state.config.RELEVANCE_THRESHOLD
            )
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=k,
                reranking_function=app.state.sentence_transformer_rf,
                r=r,
            )
        else:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=k,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
615
616


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
617
618
619
# Request body for POST /query/collection.
class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str
    k: Union[int, None] = None  # overrides the configured TOP_K
    r: Union[float, None] = None  # overrides the configured relevance threshold
    # NOTE(review): not read by query_collection_handler; the global hybrid flag decides.
    hybrid: Union[bool, None] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
623
624


625
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several vector collections at once.

    When ENABLE_RAG_HYBRID_SEARCH is on, this runs hybrid search with the
    reranker; otherwise it runs plain embedding search. ``k`` falls back to
    the configured TOP_K when missing or 0. ``r`` falls back to
    RELEVANCE_THRESHOLD only when missing, so an explicit ``r=0.0`` is honoured
    instead of being replaced by the configured threshold.

    NOTE(review): ``form_data.hybrid`` is not consulted; the global setting decides.

    Raises:
        HTTPException: 400 wrapping any error raised by the query.
    """
    try:
        k = form_data.k if form_data.k else app.state.config.TOP_K
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            # `is not None` rather than truthiness: 0.0 is a valid threshold.
            r = (
                form_data.r
                if form_data.r is not None
                else app.state.config.RELEVANCE_THRESHOLD
            )
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=k,
                reranking_function=app.state.sentence_transformer_rf,
                r=r,
            )
        else:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=k,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
656
657


Timothy J. Baek's avatar
Timothy J. Baek committed
658
659
660
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video's transcript (with video info) and index it.

    An empty ``collection_name`` is replaced by the first 63 hex characters
    of the URL's SHA-256. The target collection is overwritten.

    Raises:
        HTTPException: 400 if loading or indexing raises.
    """
    try:
        transcript_loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        transcript_docs = transcript_loader.load()

        target_collection = (
            calculate_sha256_string(form_data.url)[:63]
            if form_data.collection_name == ""
            else form_data.collection_name
        )

        store_data_in_vector_db(transcript_docs, target_collection, overwrite=True)
        response = {
            "status": True,
            "collection_name": target_collection,
            "filename": form_data.url,
        }
        return response
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


687
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a single web page and index it into a vector collection.

    The page is loaded via ``get_web_loader`` using the
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION setting. An empty
    ``collection_name`` is replaced by the first 63 hex characters of the
    URL's SHA-256. The target collection is overwritten.

    Raises:
        HTTPException: 400 if validation, fetching or indexing raises.
    """
    # e.g. "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        page_loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        pages = page_loader.load()

        target_collection = form_data.collection_name
        if target_collection == "":
            target_collection = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(pages, target_collection, overwrite=True)
        return dict(
            status=True,
            collection_name=target_collection,
            filename=form_data.url,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

714

715
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a SafeWebBaseLoader for one URL or a sequence of URLs.

    Args:
        url: A single URL or a sequence of URLs to fetch.
        verify_ssl: Passed through to the loader's ``verify_ssl``.

    Raises:
        ValueError: ERROR_MESSAGES.INVALID_URL when ``validate_url`` returns a
            falsy value (``validate_url`` may also raise it directly).
    """
    url_ok = validate_url(url)
    if url_ok:
        return SafeWebBaseLoader(
            url,
            verify_ssl=verify_ssl,
            requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            continue_on_failure=True,
        )
    raise ValueError(ERROR_MESSAGES.INVALID_URL)
725
726


727
728
729
730
def is_non_public_ip(ip: str) -> bool:
    """Return True if ``ip`` is not a publicly routable unicast address.

    Flags anything ``ipaddress`` does not consider global (private, loopback,
    link-local, unspecified, ...) plus multicast and reserved ranges, for both
    IPv4 and IPv6. An IPv6 zone suffix such as ``fe80::1%eth0`` (which
    ``getaddrinfo`` can return) is stripped before parsing.

    Raises:
        ValueError: if ``ip`` is not a valid IP address string.
    """
    import ipaddress

    addr = ipaddress.ip_address(ip.split("%", 1)[0])
    return (not addr.is_global) or addr.is_multicast or addr.is_reserved


def validate_url(url: Union[str, Sequence[str]]):
    """Validate a URL, or every URL in a sequence, before it is fetched.

    Returns:
        True if every URL is acceptable; False if ``url`` is neither a string
        nor a sequence.

    Raises:
        ValueError: ERROR_MESSAGES.INVALID_URL if a URL is malformed or, when
            ENABLE_RAG_LOCAL_WEB_FETCH is off, if its host resolves to any
            non-public address.
    """
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled: reject URLs whose host resolves to a
            # non-public address. The stdlib ``ipaddress`` module is used
            # because ``validators.ipv6`` takes no ``private=`` flag; its
            # decorator turns the resulting TypeError into a falsy result, so
            # an IPv6 check through it never matches (e.g. ``::1``).
            # This is technically still vulnerable to DNS rebinding, as we
            # don't control how WebBaseLoader resolves the host when fetching.
            parsed_url = urllib.parse.urlparse(url)
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            if any(is_non_public_ip(ip) for ip in ipv4_addresses + ipv6_addresses):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False

Timothy J. Baek's avatar
Timothy J. Baek committed
750

Timothy J. Baek's avatar
revert  
Timothy J. Baek committed
751
752
753
754
755
756
757
758
759
760
761
def resolve_hostname(hostname):
    """Resolve ``hostname`` and split the results by address family.

    Returns:
        ``(ipv4_addresses, ipv6_addresses)``: two lists of address strings, in
        ``socket.getaddrinfo`` order (duplicates are kept).

    Raises:
        socket.gaierror: if the name cannot be resolved.
    """
    v4_hosts, v6_hosts = [], []
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        # sockaddr[0] is the textual address for both AF_INET and AF_INET6.
        if family == socket.AF_INET:
            v4_hosts.append(sockaddr[0])
        elif family == socket.AF_INET6:
            v6_hosts.append(sockaddr[0])
    return v4_hosts, v6_hosts


Timothy J. Baek's avatar
Timothy J. Baek committed
762
763
764
765
766
767
768
769
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Run ``query`` against the web search backend named by ``engine``.

    The backend is chosen by ``engine``, not by which credentials are set.
    Each backend reads its settings from ``app.state.config``:

    - ``searxng``: SEARXNG_QUERY_URL
    - ``google_pse``: GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID
    - ``brave``: BRAVE_SEARCH_API_KEY
    - ``serpstack``: SERPSTACK_API_KEY (and SERPSTACK_HTTPS)
    - ``serper``: SERPER_API_KEY
    - ``serply``: SERPLY_API_KEY
    - ``duckduckgo``: no key needed
    - ``tavily``: TAVILY_API_KEY

    Every backend receives RAG_WEB_SEARCH_RESULT_COUNT. All except tavily
    also receive RAG_WEB_SEARCH_DOMAIN_FILTER_LIST.

    Args:
        engine: Backend name, e.g. ``"searxng"``.
        query (str): The query to search for.

    Raises:
        Exception: if the selected backend's credentials are missing or the
            engine name is not recognised.
    """
    # TODO: add playwright to search the web
    cfg = app.state.config

    if engine == "searxng":
        if not cfg.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(
            cfg.SEARXNG_QUERY_URL,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
            cfg.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "google_pse":
        if not (cfg.GOOGLE_PSE_API_KEY and cfg.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            cfg.GOOGLE_PSE_API_KEY,
            cfg.GOOGLE_PSE_ENGINE_ID,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
            cfg.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "brave":
        if not cfg.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(
            cfg.BRAVE_SEARCH_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
            cfg.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "serpstack":
        if not cfg.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            cfg.SERPSTACK_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
            cfg.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            https_enabled=cfg.SERPSTACK_HTTPS,
        )

    if engine == "serper":
        if not cfg.SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(
            cfg.SERPER_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
            cfg.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "serply":
        if not cfg.SERPLY_API_KEY:
            raise Exception("No SERPLY_API_KEY found in environment variables")
        return search_serply(
            cfg.SERPLY_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
            cfg.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "duckduckgo":
        return search_duckduckgo(
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
            cfg.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "tavily":
        if not cfg.TAVILY_API_KEY:
            raise Exception("No TAVILY_API_KEY found in environment variables")
        return search_tavily(
            cfg.TAVILY_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    raise Exception("No search engine API key found in environment variables")


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
863
864
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Search the web for ``form_data.query`` and index the result pages.

    The engine is RAG_WEB_SEARCH_ENGINE. Result pages are fetched with the
    same ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION setting as the ``/web``
    endpoint. They are stored, overwriting, in ``form_data.collection_name``,
    or in a collection named after the query's SHA-256 (first 63 hex chars)
    when that name is empty.

    Raises:
        HTTPException: 400 with WEB_SEARCH_ERROR if the search itself fails;
            400 with DEFAULT if fetching or indexing the result pages fails.
    """
    try:
        engine = app.state.config.RAG_WEB_SEARCH_ENGINE
        # Use the module logger like the rest of the file; log engine and
        # query as separate values rather than a formatted tuple.
        log.info(f"trying to web search with {engine}: {form_data.query}")
        web_results = search_web(engine, form_data.query)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        urls = [result.link for result in web_results]
        # Honour the admin's SSL-verification setting, as store_web does.
        loader = get_web_loader(
            urls,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


904
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Chunking uses the configured CHUNK_SIZE / CHUNK_OVERLAP and records each
    chunk's start index in its metadata (``add_start_index=True``).

    Args:
        data: Documents as returned by a LangChain loader's ``load()``.
        collection_name: Target Chroma collection name.
        overwrite: If True, an existing collection with that name is replaced.

    Returns:
        The boolean returned by ``store_docs_in_vector_db``.

    Raises:
        ValueError: ERROR_MESSAGES.EMPTY_CONTENT if splitting yields no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    # Return the bool itself. A ``(result, None)`` tuple is always truthy, so
    # callers testing ``if result:`` could never observe a failed store.
    return store_docs_in_vector_db(docs, collection_name, overwrite)
919
920
921


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk a raw string and store the chunks in ``collection_name``.

    The whole string is passed to the splitter with ``metadata`` as its
    metadata. Chunk size and overlap come from the app config.

    Returns:
        Whatever ``store_docs_in_vector_db`` returns.
    """
    chunker = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = chunker.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
933
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed ``docs`` and add them to the Chroma collection ``collection_name``.

    Args:
        docs: Chunked documents exposing ``page_content`` and ``metadata``.
            datetime metadata values are converted to ``str`` in place.
        collection_name: Collection to create and fill.
        overwrite: If True, an existing collection with this name is deleted
            before it is recreated.

    Returns:
        True on success, or when the raised exception's class is named
        ``UniqueConstraintError`` (presumably Chroma reporting the collection
        already exists; confirm against the installed chromadb). False for
        any other failure, which is logged rather than raised.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    # ChromaDB does not accept datetime metadata values, so stringify them in
    # place. The comprehension is fully built before update() mutates the dict.
    for meta in metadatas:
        meta.update(
            {key: str(val) for key, val in meta.items() if isinstance(val, datetime)}
        )

    try:
        if overwrite and any(
            existing.name == collection_name
            for existing in CHROMA_CLIENT.list_collections()
        ):
            log.info(f"deleting existing collection {collection_name}")
            CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embed = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )
        # Newlines are flattened only in the text sent to the embedder; the
        # stored documents keep their original text.
        embeddings = embed([text.replace("\n", " ") for text in texts])

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        return type(e).__name__ == "UniqueConstraintError"


985
986
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a LangChain document loader for a file.

    The extension (text after the last ``.``, lower-cased) and the MIME type
    are checked in a fixed priority order, and the first match wins.
    Unmatched files fall back to ``TextLoader`` with encoding autodetection.

    Returns:
        ``(loader, known_type)``. ``known_type`` is False only when the file
        matched no rule (not even a source-code extension or a ``text/`` MIME
        type) and got the fallback TextLoader.
    """
    ext = filename.split(".")[-1].lower()

    # Extensions read as plain text / source code.
    source_extensions = {
        "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
        "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", "r",
        "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", "nginxconf",
        "conf", "m", "mm", "plsql", "perl", "rb", "rs", "db2", "scala", "bash",
        "swift", "vue", "svelte", "msg",
    }
    excel_types = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
    powerpoint_types = (
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )
    docx_type = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"

    if ext == "pdf":
        return (
            PyPDFLoader(file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES),
            True,
        )
    if ext == "csv":
        return CSVLoader(file_path), True
    if ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if ext in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    if file_content_type == docx_type or ext in ("doc", "docx"):
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_types or ext in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True
    if file_content_type in powerpoint_types or ext in ("ppt", "pptx"):
        return UnstructuredPowerPointLoader(file_path), True
    if ext == "msg":
        return OutlookMessageLoader(file_path), True

    # Both remaining cases use the same TextLoader and differ only in whether
    # the type counts as "known".
    recognised_as_text = ext in source_extensions or bool(
        file_content_type and "text/" in file_content_type
    )
    return TextLoader(file_path, autodetect_encoding=True), recognised_as_text


1080
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under UPLOAD_DIR and index it into a collection.

    If ``collection_name`` is omitted, it is derived from the SHA-256 of the
    saved file (first 63 hex characters).

    Returns:
        ``{"status", "collection_name", "filename", "known_type"}`` on success.

    Raises:
        HTTPException: 400 if saving, loading or indexing raises (with
            PANDOC_NOT_INSTALLED when the error mentions missing pandoc);
            500 if the vector store reports failure without raising.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        # basename() drops any client-supplied directory components.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            # ``with`` closes the handle even if hashing raises.
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()
        result = store_data_in_vector_db(data, collection_name)
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    # Previously a falsy result fell through and returned null with 200.
    if not result:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "filename": filename,
        "known_type": known_type,
    }
1135
1136


Timothy J. Baek's avatar
Timothy J. Baek committed
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
class ProcessDocForm(BaseModel):
    file_id: str
    # Optional target collection. When omitted, the name is derived from the
    # file's SHA-256.
    collection_name: Optional[str] = None


@app.post("/process/doc")
def process_doc(
    form_data: ProcessDocForm,
    user=Depends(get_current_user),
):
    """Index a previously uploaded file, looked up by id, into a collection.

    The file path comes from the file record's ``meta["path"]``, falling back
    to ``UPLOAD_DIR/<filename>``. The collection name comes from the form, or
    from the file's SHA-256 (first 63 hex characters) when not given.

    Returns:
        ``{"status", "collection_name", "known_type"}`` on success.

    Raises:
        HTTPException: 400 if lookup, loading or indexing raises (with
            PANDOC_NOT_INSTALLED when the error mentions missing pandoc);
            500 if the vector store reports failure without raising.
    """
    try:
        file = Files.get_file_by_id(form_data.file_id)
        file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")

        # Bind the local before testing it. Reading it unassigned raised
        # UnboundLocalError on every request.
        collection_name = form_data.collection_name
        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(
            file.filename, file.meta.get("content_type"), file_path
        )
        data = loader.load()
        result = store_data_in_vector_db(data, collection_name)
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    if not result:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "known_type": known_type,
    }


1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
class TextRAGForm(BaseModel):
    name: str
    content: str
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index raw text into a vector collection.

    Each chunk is tagged with ``{"name": form_data.name, "created_by":
    user.id}``. A missing or empty ``collection_name`` is replaced by the
    first 63 hex characters of the content's SHA-256.

    Returns:
        ``{"status": True, "collection_name": ...}`` on success.

    Raises:
        HTTPException: 500 if the vector store reports failure.
    """
    collection_name = form_data.collection_name
    if not collection_name:
        # A SHA-256 hex digest is 64 chars. Truncate to 63, as the other
        # handlers do, so the name fits Chroma's 63-character limit.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


1219
1220
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Admin-only: index every non-hidden file found under DOCS_DIR.

    For each file:
    - the content is stored in a collection named after its SHA-256
      (first 63 hex characters);
    - a Documents record is inserted if none exists yet for the sanitized
      filename.

    Tags come from ``extract_folders_after_data_docs(path)`` (presumably the
    sub-folder names; confirm there). A failure on one file is logged and the
    scan moves on to the next.

    Returns:
        Always True.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # ``with`` closes the handle even if hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            result = store_data_in_vector_db(data, collection_name)
            if not result:
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is not None:
                continue

            content = (
                json.dumps({"tags": [{"name": name} for name in tags]})
                if len(tags)
                else "{}"
            )
            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    **{
                        "name": sanitized_filename,
                        "title": filename,
                        "collection_name": collection_name,
                        "filename": filename,
                        "content": content,
                    }
                ),
            )
        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1280
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Admin-only: call ``reset()`` on the shared Chroma client.

    Returns None. Any exception from ``reset()`` propagates to FastAPI.
    """
    chroma = CHROMA_CLIENT
    chroma.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
1283
1284


Timothy J. Baek's avatar
Timothy J. Baek committed
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    """Admin-only: delete every file, link and sub-directory in UPLOAD_DIR.

    Best-effort: problems are reported through the module logger, like
    ``reset`` does, rather than printed to stdout, and are never raised.

    Returns:
        Always True.
    """
    folder = f"{UPLOAD_DIR}"
    try:
        if not os.path.exists(folder):
            log.warning(f"The directory {folder} does not exist")
            return True

        # Iterate over all the files and directories in the specified directory
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove the file or link
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove the directory
            except Exception as e:
                log.error(f"Failed to delete {file_path}. Reason: {e}")
    except Exception as e:
        log.error(f"Failed to process the directory {folder}. Reason: {e}")

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1309
@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Admin-only: empty UPLOAD_DIR, then reset the Chroma client.

    Per-entry deletion failures and a failing Chroma reset are logged, not
    raised. A missing UPLOAD_DIR is tolerated so the vector-store reset
    still runs; ``os.listdir`` used to raise there and abort the request.

    Returns:
        Always True.
    """
    folder = f"{UPLOAD_DIR}"
    if os.path.isdir(folder):
        entries = os.listdir(folder)
    else:
        log.warning("Upload directory %s does not exist; skipping", folder)
        entries = []

    for filename in entries:
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s", file_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1328

Timothy J. Baek's avatar
Timothy J. Baek committed
1329

1330
1331
class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader that logs and skips URLs that fail instead of aborting."""

    def lazy_load(self) -> Iterator[Document]:
        """Yield one Document per URL in ``self.web_paths``.

        Metadata always contains ``source``. It also gets ``title``,
        ``description`` and ``language`` when the page has a truthy
        ``<title>``, ``<meta name="description">`` or ``<html>`` element,
        with placeholder text when the attribute is absent. Any exception
        for a URL, including one raised while yielding, is logged and that
        URL is skipped.
        """
        for web_path in self.web_paths:
            try:
                page = self._scrape(web_path, bs_kwargs=self.bs_kwargs)
                page_text = page.get_text(**self.bs_get_text_kwargs)

                page_meta = {"source": web_path}

                title_tag = page.find("title")
                if title_tag:
                    page_meta["title"] = title_tag.get_text()

                description_tag = page.find("meta", attrs={"name": "description"})
                if description_tag:
                    page_meta["description"] = description_tag.get(
                        "content", "No description found."
                    )

                html_tag = page.find("html")
                if html_tag:
                    page_meta["language"] = html_tag.get("lang", "No language found.")

                yield Document(page_content=page_text, metadata=page_meta)
            except Exception as e:
                # Log the error and continue with the next URL
                log.error(f"Error loading {web_path}: {e}")
Timothy J. Baek's avatar
Timothy J. Baek committed
1355
1356


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1357
1358
1359
1360
1361
1362
1363
1364
1365
if ENV == "dev":
    # Debug endpoints for inspecting the configured embedding function.
    # They are registered only when ENV == "dev".

    @app.get("/ef")
    async def get_embeddings():
        """Return the embedding of the fixed string "hello world"."""
        sample_vector = app.state.EMBEDDING_FUNCTION("hello world")
        return {"result": sample_vector}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        """Return the embedding of the path parameter ``text``."""
        text_vector = app.state.EMBEDDING_FUNCTION(text)
        return {"result": text_vector}