main.py 45 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
Que Nguyen's avatar
Que Nguyen committed
11
import requests
12
import os, shutil, logging, re
mindspawn's avatar
mindspawn committed
13
from datetime import datetime
14
15

from pathlib import Path
16
from typing import List, Union, Sequence, Iterator, Any
Timothy J. Baek's avatar
Timothy J. Baek committed
17

18
from chromadb.utils.batch_utils import create_batches
19
from langchain_core.documents import Document
Timothy J. Baek's avatar
Timothy J. Baek committed
20

Timothy J. Baek's avatar
Timothy J. Baek committed
21
22
23
24
25
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
26
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
28
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
29
30
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
31
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
32
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
33
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
34
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
35
    YoutubeLoader,
mindspawn's avatar
mindspawn committed
36
    OutlookMessageLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
37
)
38
39
from langchain.text_splitter import RecursiveCharacterTextSplitter

40
41
42
43
44
import validators
import urllib.parse
import socket


45
46
from pydantic import BaseModel
from typing import Optional
47
import mimetypes
48
import uuid
49
50
import json

51
import sentence_transformers
52

Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
53
from apps.webui.models.documents import (
54
55
56
57
    Documents,
    DocumentForm,
    DocumentResponse,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
58
59
60
from apps.webui.models.files import (
    Files,
)
Jannik Streidl's avatar
Jannik Streidl committed
61

62
from apps.rag.utils import (
63
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
64
65
66
67
68
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
69
)
Timothy J. Baek's avatar
Timothy J. Baek committed
70

Timothy J. Baek's avatar
Timothy J. Baek committed
71
72
73
74
75
76
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
77
from apps.rag.search.serply import search_serply
78
from apps.rag.search.duckduckgo import search_duckduckgo
79
from apps.rag.search.tavily import search_tavily
80
from apps.rag.search.jina_search import search_jina
Timothy J. Baek's avatar
Timothy J. Baek committed
81

82
83
84
85
86
87
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
88
from utils.utils import get_current_user, get_admin_user
89

90
from config import (
91
    AppConfig,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
92
    ENV,
93
    SRC_LOG_LEVELS,
94
95
    UPLOAD_DIR,
    DOCS_DIR,
96
97
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
98
    RAG_EMBEDDING_ENGINE,
99
    RAG_EMBEDDING_MODEL,
100
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
101
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
102
    ENABLE_RAG_HYBRID_SEARCH,
103
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
104
    RAG_RERANKING_MODEL,
105
    PDF_EXTRACT_IMAGES,
106
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
107
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
108
109
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
110
    DEVICE_TYPE,
111
112
113
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
114
    RAG_TEMPLATE,
115
    ENABLE_RAG_LOCAL_WEB_FETCH,
116
    YOUTUBE_LOADER_LANGUAGE,
Timothy J. Baek's avatar
Timothy J. Baek committed
117
    ENABLE_RAG_WEB_SEARCH,
Timothy J. Baek's avatar
Timothy J. Baek committed
118
    RAG_WEB_SEARCH_ENGINE,
Que Nguyen's avatar
Que Nguyen committed
119
    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
Timothy J. Baek's avatar
Timothy J. Baek committed
120
121
122
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
Timothy J. Baek's avatar
Timothy J. Baek committed
123
    BRAVE_SEARCH_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
124
125
126
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
127
    SERPLY_API_KEY,
128
    TAVILY_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
129
    RAG_WEB_SEARCH_RESULT_COUNT,
130
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
131
    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
132
)
133

134
135
from constants import ERROR_MESSAGES

136
137
138
# Module logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

# FastAPI application object that the routes below are registered on.
app = FastAPI()

# Runtime-adjustable settings live on app.state.config; every attribute is
# seeded from the corresponding value imported from the `config` module.
app.state.config = AppConfig()

# --- Retrieval defaults ---
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD

app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# --- Text splitting ---
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# --- Embedding / reranking models and prompt template ---
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# --- Remote embedding endpoint credentials ---
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# --- Document loaders ---
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
# Unlike the other settings, the translation target is stored directly on
# app.state (not on app.state.config) and starts out unset.
app.state.YOUTUBE_LOADER_TRANSLATION = None

# --- Web search ---
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST

app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS


188
189
190
191
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load or clear the in-process SentenceTransformer embedding model.

    The model is loaded only when ``embedding_model`` is non-empty and the
    configured ``RAG_EMBEDDING_ENGINE`` is the empty string; otherwise
    ``app.state.sentence_transformer_ef`` is reset to ``None``.

    Args:
        embedding_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    wants_local_model = (
        bool(embedding_model) and app.state.config.RAG_EMBEDDING_ENGINE == ""
    )
    if not wants_local_model:
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load or clear the CrossEncoder used for reranking.

    With an empty ``reranking_model`` the reranker stored on
    ``app.state.sentence_transformer_rf`` is reset to ``None``; otherwise a
    ``sentence_transformers.CrossEncoder`` is built from the resolved path.

    Args:
        reranking_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the embedding and reranking models once at import time, using the
# auto-update flags imported from `config`.
update_embedding_model(
    embedding_model=app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)

update_reranking_model(
    reranking_model=app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Build the embedding callable from the current configuration; it is reused
# by the query handlers via app.state.EMBEDDING_FUNCTION.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
236
237
# CORS policy: any origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended for every deployment.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
248
class CollectionNameForm(BaseModel):
    # Base request body carrying the name of the target collection.
    # Falls back to "test" when the client omits the field.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
252
class UrlForm(CollectionNameForm):
    # Request body for endpoints that ingest a single URL (e.g. /youtube).
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
255

256
257
258
259
class SearchForm(CollectionNameForm):
    # Request body carrying a free-text query plus the inherited collection name.
    query: str


Timothy J. Baek's avatar
Timothy J. Baek committed
260
261
@app.get("/")
async def get_status():
    # Unauthenticated status probe: reports the current chunking, template,
    # embedding and reranking settings.
    cfg = app.state.config
    return {
        "status": True,
        "chunk_size": cfg.CHUNK_SIZE,
        "chunk_overlap": cfg.CHUNK_OVERLAP,
        "template": cfg.RAG_TEMPLATE,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "reranking_model": cfg.RAG_RERANKING_MODEL,
        "openai_batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
274
275
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    # Admin-only: returns the embedding engine/model and the remote endpoint
    # settings (URL, API key, batch size) currently in effect.
    cfg = app.state.config
    openai_config = {
        "url": cfg.OPENAI_API_BASE_URL,
        "key": cfg.OPENAI_API_KEY,
        "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }
    return {
        "status": True,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
288
289
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    # Admin-only: returns the configured reranking model.
    # NOTE: the handler name is misspelled ("reraanking"); it is left as-is so
    # the existing function name does not change.
    reranking_model = app.state.config.RAG_RERANKING_MODEL
    return {
        "status": True,
        "reranking_model": reranking_model,
    }
Steven Kreitzer's avatar
Steven Kreitzer committed
294
295


296
297
298
class OpenAIConfigForm(BaseModel):
    # Connection settings for a remote embedding endpoint.
    url: str
    key: str
    # Optional; the update handler substitutes 1 when this is missing or 0.
    batch_size: Optional[int] = None
300
301


302
class EmbeddingModelUpdateForm(BaseModel):
    # Request body for POST /embedding/update.
    # openai_config is only applied when the engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
308
309
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    # Admin-only: switches the embedding engine/model, optionally updates the
    # remote endpoint settings, reloads the local model and rebuilds
    # app.state.EMBEDDING_FUNCTION. Any failure is logged and reported as
    # HTTP 500.
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        cfg = app.state.config
        cfg.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        cfg.RAG_EMBEDDING_MODEL = form_data.embedding_model

        remote_settings = form_data.openai_config
        if (
            cfg.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
            and remote_settings is not None
        ):
            cfg.OPENAI_API_BASE_URL = remote_settings.url
            cfg.OPENAI_API_KEY = remote_settings.key
            # A missing or zero batch size falls back to 1.
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE = remote_settings.batch_size or 1

        update_embedding_model(cfg.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            cfg.RAG_EMBEDDING_ENGINE,
            cfg.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            cfg.OPENAI_API_KEY,
            cfg.OPENAI_API_BASE_URL,
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        openai_config = {
            "url": cfg.OPENAI_API_BASE_URL,
            "key": cfg.OPENAI_API_KEY,
            "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        }
        return {
            "status": True,
            "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
            "embedding_model": cfg.RAG_EMBEDDING_MODEL,
            "openai_config": openai_config,
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
356
357


Steven Kreitzer's avatar
Steven Kreitzer committed
358
359
class RerankingModelUpdateForm(BaseModel):
    # Request body for POST /reranking/update.
    reranking_model: str
360

Steven Kreitzer's avatar
Steven Kreitzer committed
361
362
363
364
365
366

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    # Admin-only: stores the new reranking model name and reloads the
    # reranker. Any failure is logged and reported as HTTP 500.
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # Pass update_model=True as an argument. The previous code had the
        # closing parenthesis in the wrong place
        # (`update_reranking_model(model), True`), which built and discarded
        # a tuple and left update_model at its default of False.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
386
387
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    # Admin-only: returns the loader, chunking, YouTube and web-search
    # settings currently in effect (including the search provider API keys).
    cfg = app.state.config

    search_settings = {
        "enabled": cfg.ENABLE_RAG_WEB_SEARCH,
        "engine": cfg.RAG_WEB_SEARCH_ENGINE,
        "searxng_query_url": cfg.SEARXNG_QUERY_URL,
        "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
        "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
        "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
        "serpstack_api_key": cfg.SERPSTACK_API_KEY,
        "serpstack_https": cfg.SERPSTACK_HTTPS,
        "serper_api_key": cfg.SERPER_API_KEY,
        "serply_api_key": cfg.SERPLY_API_KEY,
        "tavily_api_key": cfg.TAVILY_API_KEY,
        "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    }

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            # Stored on app.state itself, not on app.state.config.
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": search_settings,
        },
    }


class ChunkParamUpdateForm(BaseModel):
    # Chunking parameters; both values are required when the section is sent.
    chunk_size: int
    chunk_overlap: int


425
426
427
428
429
class YoutubeLoaderConfig(BaseModel):
    # Settings passed to the YouTube loader as `language` and `translation`.
    language: List[str]
    translation: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
430
class WebSearchConfig(BaseModel):
    # Web-search section of the RAG config update payload.
    # Only `enabled` is required; every other field defaults to None.
    enabled: bool
    engine: Optional[str] = None

    # Provider-specific settings.
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    serply_api_key: Optional[str] = None
    tavily_api_key: Optional[str] = None

    # Limits applied to a search run.
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
446
447
448
449
450
class WebConfig(BaseModel):
    # Web section of the RAG config update payload.
    search: WebSearchConfig
    # NOTE: GET /config exposes this setting under the key "ssl_verification".
    web_loader_ssl_verification: Optional[bool] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
451
class ConfigUpdateForm(BaseModel):
    # Request body for POST /config/update. Every section is optional; the
    # handler leaves a section untouched when it is omitted (None).
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
456
457
458
459


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    # Admin-only: applies the sections present in the payload (PDF image
    # extraction, chunking, YouTube loader, web loader/search) and returns the
    # resulting configuration in the same shape as GET /config.
    cfg = app.state.config

    cfg.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else cfg.PDF_EXTRACT_IMAGES
    )

    if form_data.chunk is not None:
        cfg.CHUNK_SIZE = form_data.chunk.chunk_size
        cfg.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        cfg.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        # Stored on app.state itself, not on app.state.config.
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        # Only overwrite SSL verification when the client actually sent a
        # value, mirroring how pdf_extract_images is handled above. GET
        # /config exposes this setting as "ssl_verification", so a client
        # echoing that payload back omits `web_loader_ssl_verification`;
        # previously that replaced the setting with None.
        if form_data.web.web_loader_ssl_verification is not None:
            cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
                form_data.web.web_loader_ssl_verification
            )

        search = form_data.web.search
        cfg.ENABLE_RAG_WEB_SEARCH = search.enabled
        cfg.RAG_WEB_SEARCH_ENGINE = search.engine
        cfg.SEARXNG_QUERY_URL = search.searxng_query_url
        cfg.GOOGLE_PSE_API_KEY = search.google_pse_api_key
        cfg.GOOGLE_PSE_ENGINE_ID = search.google_pse_engine_id
        cfg.BRAVE_SEARCH_API_KEY = search.brave_search_api_key
        cfg.SERPSTACK_API_KEY = search.serpstack_api_key
        cfg.SERPSTACK_HTTPS = search.serpstack_https
        cfg.SERPER_API_KEY = search.serper_api_key
        cfg.SERPLY_API_KEY = search.serply_api_key
        cfg.TAVILY_API_KEY = search.tavily_api_key
        cfg.RAG_WEB_SEARCH_RESULT_COUNT = search.result_count
        cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = search.concurrent_requests

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": cfg.ENABLE_RAG_WEB_SEARCH,
                "engine": cfg.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": cfg.SEARXNG_QUERY_URL,
                "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": cfg.SERPSTACK_API_KEY,
                "serpstack_https": cfg.SERPSTACK_HTTPS,
                "serper_api_key": cfg.SERPER_API_KEY,
                "serply_api_key": cfg.SERPLY_API_KEY,
                "tavily_api_key": cfg.TAVILY_API_KEY,
                "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
529
530


Timothy J. Baek's avatar
Timothy J. Baek committed
531
532
533
534
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    # Any authenticated user: returns the RAG prompt template in effect.
    template = app.state.config.RAG_TEMPLATE
    return {
        "status": True,
        "template": template,
    }


539
540
541
542
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    # Admin-only: returns the template, top-k ("k"), relevance threshold ("r")
    # and hybrid-search flag used by the query endpoints.
    cfg = app.state.config
    return {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
548
549


550
551
class QuerySettingsForm(BaseModel):
    # Request body for POST /query/settings/update. Every field is optional;
    # the handler substitutes a fixed fallback for missing/falsy values.
    k: Optional[int] = None
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None
555
556
557
558
559
560


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    # Admin-only: overwrites all four query settings. A missing or falsy
    # field is replaced by its fallback: RAG_TEMPLATE (as imported from
    # `config`) for the template, 4 for k, 0.0 for r and False for hybrid.
    cfg = app.state.config

    cfg.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    cfg.TOP_K = form_data.k or 4
    cfg.RELEVANCE_THRESHOLD = form_data.r or 0.0
    cfg.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
576
577


578
class QueryDocForm(BaseModel):
    # Request body for POST /query/doc: one collection, one query.
    collection_name: str
    query: str
    # Optional per-request overrides of the configured top-k / threshold.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): accepted by the model, but the /query/doc handler decides
    # on hybrid search from the server configuration only.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
584
585


586
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    # Any authenticated user: queries a single collection, using hybrid
    # search when it is enabled in the configuration. Any failure is logged
    # and reported as HTTP 400.
    try:
        # A missing (or zero) k falls back to the configured TOP_K.
        top_k = form_data.k if form_data.k else app.state.config.TOP_K

        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            # Compare against None so that an explicit r=0.0 is honoured
            # instead of being replaced by the configured threshold.
            threshold = (
                form_data.r
                if form_data.r is not None
                else app.state.config.RELEVANCE_THRESHOLD
            )
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
                reranking_function=app.state.sentence_transformer_rf,
                r=threshold,
            )

        return query_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
616
617


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
618
619
620
class QueryCollectionsForm(BaseModel):
    # Request body for POST /query/collection: several collections, one query.
    collection_names: List[str]
    query: str
    # Optional per-request overrides of the configured top-k / threshold.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): accepted by the model, but the /query/collection handler
    # decides on hybrid search from the server configuration only.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
624
625


626
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    # Any authenticated user: queries several collections at once, using
    # hybrid search when it is enabled in the configuration. Any failure is
    # logged and reported as HTTP 400.
    try:
        # A missing (or zero) k falls back to the configured TOP_K.
        top_k = form_data.k if form_data.k else app.state.config.TOP_K

        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            # Compare against None so that an explicit r=0.0 is honoured
            # instead of being replaced by the configured threshold.
            threshold = (
                form_data.r
                if form_data.r is not None
                else app.state.config.RELEVANCE_THRESHOLD
            )
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
                reranking_function=app.state.sentence_transformer_rf,
                r=threshold,
            )

        return query_collection(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
657
658


Timothy J. Baek's avatar
Timothy J. Baek committed
659
660
661
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video through ``YoutubeLoader`` and store it in the vector DB.

    The loader is configured with the language from ``app.state.config`` and
    the translation setting from ``app.state``. An empty ``collection_name``
    is replaced by the first 63 characters of the URL's SHA-256 string. Any
    existing collection of that name is overwritten.

    Returns:
        dict: ``status``, the ``collection_name`` used and the URL as ``filename``.

    Raises:
        HTTPException: 400 wrapping any exception raised while loading or storing.
    """
    try:
        documents = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        ).load()

        requested_name = form_data.collection_name
        # Only the empty string triggers the derived name.
        target_collection = (
            calculate_sha256_string(form_data.url)[:63]
            if requested_name == ""
            else requested_name
        )

        store_data_in_vector_db(documents, target_collection, overwrite=True)
        return {
            "status": True,
            "collection_name": target_collection,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


688
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and store its content in the vector DB.

    The page is loaded through ``get_web_loader`` using the configured SSL
    verification setting. An empty ``collection_name`` is replaced by the
    first 63 characters of the URL's SHA-256 string. Any existing collection
    of that name is overwritten.

    Returns:
        dict: ``status``, the ``collection_name`` used and the URL as ``filename``.

    Raises:
        HTTPException: 400 wrapping any exception raised while loading or storing.
    """
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        documents = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        ).load()

        requested_name = form_data.collection_name
        # Only the empty string triggers the derived name.
        target_collection = (
            calculate_sha256_string(form_data.url)[:63]
            if requested_name == ""
            else requested_name
        )

        store_data_in_vector_db(documents, target_collection, overwrite=True)
        return {
            "status": True,
            "collection_name": target_collection,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

715

716
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a ``SafeWebBaseLoader`` for one URL or a sequence of URLs.

    Args:
        url: A single URL or a sequence of URLs; validated with ``validate_url``.
        verify_ssl: Passed through to the loader.

    Raises:
        ValueError: If ``validate_url`` rejects the input (it either raises
            itself or returns a falsy value).
    """
    if validate_url(url):
        return SafeWebBaseLoader(
            url,
            verify_ssl=verify_ssl,
            requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            continue_on_failure=True,
        )

    raise ValueError(ERROR_MESSAGES.INVALID_URL)
726
727


728
729
730
731
def validate_url(url: Union[str, Sequence[str]]):
    """Validate a URL, or every URL in a sequence.

    A string must pass ``validators.url``. When ``ENABLE_RAG_LOCAL_WEB_FETCH``
    is disabled, the hostname is additionally resolved and the URL is rejected
    if any resolved address is private, loopback, link-local or unspecified
    (IPv4 and IPv6 alike, including IPv4-mapped IPv6 addresses).

    Args:
        url: A single URL string or a sequence of URL strings.

    Returns:
        bool: ``True`` for a valid string (or a sequence whose members are all
        valid), ``False`` when ``url`` is neither a string nor a sequence.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` for a malformed URL or one
            that resolves to a non-public address while local fetch is disabled.
    """
    # Function-scope import: stdlib only, keeps this block self-contained.
    import ipaddress

    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
            parsed_url = urllib.parse.urlparse(url)
            # Get IPv4 and IPv6 addresses
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            # Check if any of the resolved addresses are private
            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
            for ip in (*ipv4_addresses, *ipv6_addresses):
                try:
                    # Resolved IPv6 addresses may carry a zone id ("fe80::1%eth0");
                    # strip it so parsing does not depend on the Python version.
                    address = ipaddress.ip_address(ip.split("%", 1)[0])
                except ValueError:
                    # Fail closed on anything that cannot be classified.
                    raise ValueError(ERROR_MESSAGES.INVALID_URL) from None

                # Classify "::ffff:a.b.c.d" by its embedded IPv4 address.
                mapped = getattr(address, "ipv4_mapped", None)
                if mapped is not None:
                    address = mapped

                if (
                    address.is_private
                    or address.is_loopback
                    or address.is_link_local
                    or address.is_unspecified
                ):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False

Timothy J. Baek's avatar
Timothy J. Baek committed
751

Timothy J. Baek's avatar
revert  
Timothy J. Baek committed
752
753
754
755
756
757
758
759
760
761
762
def resolve_hostname(hostname):
    """Resolve ``hostname`` and split the resulting addresses by family.

    Returns:
        tuple[list, list]: ``(ipv4_addresses, ipv6_addresses)`` in the order
        ``socket.getaddrinfo`` reported them; entries of any other address
        family are ignored.
    """
    ipv4_addresses = []
    ipv6_addresses = []
    by_family = {
        socket.AF_INET: ipv4_addresses,
        socket.AF_INET6: ipv6_addresses,
    }

    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        bucket = by_family.get(family)
        if bucket is not None:
            # The first element of the socket address is the IP string.
            bucket.append(sockaddr[0])

    return ipv4_addresses, ipv6_addresses


Timothy J. Baek's avatar
Timothy J. Baek committed
763
764
765
766
767
768
769
770
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Search the web with the named engine and return its results.

    The engine is selected by ``engine``; its credentials or endpoint are read
    from ``app.state.config``:
    - searxng: SEARXNG_QUERY_URL
    - google_pse: GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - brave: BRAVE_SEARCH_API_KEY
    - serpstack: SERPSTACK_API_KEY
    - serper: SERPER_API_KEY
    - serply: SERPLY_API_KEY
    - duckduckgo: no key
    - tavily: TAVILY_API_KEY
    - jina: no key

    Args:
        engine (str): Name of the search engine to use
        query (str): The query to search for

    Raises:
        Exception: If the selected engine's required setting is empty, or if
            ``engine`` is not one of the names above.
    """
    config = app.state.config

    # TODO: add playwright to search the web
    if engine == "searxng":
        if not config.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(
            config.SEARXNG_QUERY_URL,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "google_pse":
        if not (config.GOOGLE_PSE_API_KEY and config.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            config.GOOGLE_PSE_API_KEY,
            config.GOOGLE_PSE_ENGINE_ID,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "brave":
        if not config.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(
            config.BRAVE_SEARCH_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "serpstack":
        if not config.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            config.SERPSTACK_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            https_enabled=config.SERPSTACK_HTTPS,
        )

    if engine == "serper":
        if not config.SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(
            config.SERPER_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "serply":
        if not config.SERPLY_API_KEY:
            raise Exception("No SERPLY_API_KEY found in environment variables")
        return search_serply(
            config.SERPLY_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "duckduckgo":
        return search_duckduckgo(
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )

    if engine == "tavily":
        if not config.TAVILY_API_KEY:
            raise Exception("No TAVILY_API_KEY found in environment variables")
        return search_tavily(
            config.TAVILY_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "jina":
        return search_jina(query, config.RAG_WEB_SEARCH_RESULT_COUNT)

    raise Exception("No search engine API key found in environment variables")


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
866
867
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Run a web search, load the result pages and store them in the vector DB.

    The search uses the configured ``RAG_WEB_SEARCH_ENGINE``. The links of all
    results are loaded through ``get_web_loader``. An empty ``collection_name``
    is replaced by the first 63 characters of the query's SHA-256 string. Any
    existing collection of that name is overwritten.

    Returns:
        dict: ``status``, the ``collection_name`` used and the loaded URLs as
        ``filenames``.

    Raises:
        HTTPException: 400 with ``WEB_SEARCH_ERROR`` if the search fails, or
            400 with the default error message if loading or storing fails.
    """
    try:
        # Use the module logger like the rest of the file (not the root logger).
        log.info(
            f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
        )
        web_results = search_web(
            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        # log.exception already records the message and traceback.
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        urls = [result.link for result in web_results]
        loader = get_web_loader(urls)
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


907
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: Documents to split (passed to the splitter's ``split_documents``).
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        bool: The result of ``store_docs_in_vector_db``. This is a plain bool,
        not a tuple, so callers' ``if result:`` checks reflect whether storing
        actually succeeded.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` when splitting yields no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if len(docs) == 0:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    return store_docs_in_vector_db(docs, collection_name, overwrite)
922
923
924


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a single text into chunks and store them in a collection.

    Args:
        text: The raw text to split.
        metadata: Metadata dict attached to the chunks created from ``text``.
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        bool: The result of ``store_docs_in_vector_db``.
    """
    config = app.state.config
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=config.CHUNK_SIZE,
        chunk_overlap=config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
936
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a newly created Chroma collection.

    Args:
        docs: Chunks exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection with the same name is
            deleted before the new one is created.

    Returns:
        bool: ``True`` on success, or when the failure is a
        ``UniqueConstraintError`` (matched by class name); ``False`` for any
        other exception, which is logged rather than raised.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = []
    metadatas = []
    for doc in docs:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)

    # ChromaDB does not like datetime formats
    # for meta-data so convert them to string.
    for metadata in metadatas:
        datetime_keys = [
            key for key, value in metadata.items() if isinstance(value, datetime)
        ]
        for key in datetime_keys:
            metadata[key] = str(metadata[key])

    try:
        if overwrite and any(
            existing.name == collection_name
            for existing in CHROMA_CLIENT.list_collections()
        ):
            log.info(f"deleting existing collection {collection_name}")
            CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        # Newlines are flattened for embedding only; the stored documents
        # keep the original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        return e.__class__.__name__ == "UniqueConstraintError"


988
989
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a document loader for a file from its extension and content type.

    Args:
        filename: File name; the text after the last ``.`` (lower-cased) is
            used as the extension.
        file_content_type: MIME type of the file; may be falsy.
        file_path: Path handed to the loader.

    Returns:
        tuple: ``(loader, known_type)``. ``known_type`` is ``False`` only when
        nothing matched and the plain-text loader is used as a fallback.
    """
    extension = filename.split(".")[-1].lower()

    # Extensions that are read as plain text.
    text_extensions = [
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
        "msg",
    ]

    if extension == "pdf":
        return (
            PyPDFLoader(file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES),
            True,
        )
    if extension == "csv":
        return CSVLoader(file_path), True
    if extension == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if extension == "xml":
        return UnstructuredXMLLoader(file_path), True
    if extension in ["htm", "html"]:
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if extension == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True

    is_word = (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    ) or extension in ["doc", "docx"]
    if is_word:
        return Docx2txtLoader(file_path), True

    excel_types = [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ]
    if file_content_type in excel_types or extension in ["xls", "xlsx"]:
        return UnstructuredExcelLoader(file_path), True

    powerpoint_types = [
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ]
    if file_content_type in powerpoint_types or extension in ["ppt", "pptx"]:
        return UnstructuredPowerPointLoader(file_path), True

    if extension == "msg":
        return OutlookMessageLoader(file_path), True

    # Recognised and unrecognised files both go through the text loader;
    # only the flag differs.
    recognised_as_text = extension in text_extensions or (
        file_content_type and file_content_type.find("text/") >= 0
    )
    return TextLoader(file_path, autodetect_encoding=True), bool(recognised_as_text)


1083
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under ``UPLOAD_DIR`` and store it in the vector DB.

    Args:
        collection_name: Target collection. When omitted, the first 63
            characters of the saved file's SHA-256 are used.
        file: The uploaded file. Only the base name of its filename is used
            for the path on disk.
        user: Injected by ``get_current_user``; not used in the body.

    Returns:
        dict: ``status``, ``collection_name``, ``filename`` and ``known_type``.

    Raises:
        HTTPException: 400 with ``PANDOC_NOT_INSTALLED`` when the error text
            mentions missing pandoc, 400 with the default message for any
            other failure while saving, loading or storing, and 500 when
            storing reports failure without raising.
    """
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            # Context manager so the handle is released even if hashing fails.
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        # Failures raised while storing are reported by the handler below.
        # The former inner 500 wrapper never reached the client: it was
        # caught there and re-wrapped as a 400.
        result = store_data_in_vector_db(data, collection_name)

        if not result:
            # Report a failed store instead of answering with an empty body.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as a 400.
        raise
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
1138
1139


Timothy J. Baek's avatar
Timothy J. Baek committed
1140
1141
class ProcessDocForm(BaseModel):
    """Request body for ``POST /process/doc``."""

    # Id of the stored file record to process (looked up through ``Files``).
    file_id: str
    # Target collection; when None the handler derives a name from the
    # file's SHA-256.
    collection_name: Optional[str] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154


@app.post("/process/doc")
def process_doc(
    form_data: ProcessDocForm,
    user=Depends(get_current_user),
):
    """Load an already stored file by id and store it in the vector DB.

    The file record is fetched through ``Files.get_file_by_id``. Its path is
    taken from the record's ``meta["path"]``, falling back to
    ``UPLOAD_DIR/<filename>``. When no collection name is given, the first 63
    characters of the file's SHA-256 are used.

    Returns:
        dict: ``status``, ``collection_name`` and ``known_type``.

    Raises:
        HTTPException: 400 with ``PANDOC_NOT_INSTALLED`` when the error text
            mentions missing pandoc, 400 with the default message for any
            other failure while looking up, loading or storing, and 500 when
            storing reports failure without raising.
    """
    try:
        file = Files.get_file_by_id(form_data.file_id)
        file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")

        collection_name = form_data.collection_name
        if collection_name is None:
            # Context manager so the handle is released even if hashing fails.
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(
            file.filename, file.meta.get("content_type"), file_path
        )
        data = loader.load()

        # Failures raised while storing are reported by the handler below.
        # The former inner 500 wrapper never reached the client: it was
        # caught there and re-wrapped as a 400.
        result = store_data_in_vector_db(data, collection_name)

        if not result:
            # Report a failed store instead of answering with an empty body.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "known_type": known_type,
        }
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as a 400.
        raise
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )


1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
class TextRAGForm(BaseModel):
    """Request body for ``POST /text``."""

    # Stored as the "name" entry of the chunk metadata.
    name: str
    # The raw text that is split and stored.
    content: str
    # Target collection; when None the handler derives a name from the
    # SHA-256 string of ``content``.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Store a piece of raw text in the vector DB.

    The text is stored with metadata holding the submitted name and the id of
    the requesting user. When no collection name is given, one is derived from
    the SHA-256 string of the content.

    Returns:
        dict: ``status`` and the ``collection_name`` used.

    Raises:
        HTTPException: 500 with the default error message when storing fails.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncate to 63 characters like every other derived collection name
        # in this file; a full hex SHA-256 (presumably 64 characters) would
        # exceed that length.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


1225
1226
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under ``DOCS_DIR`` into the vector DB.

    Each file is stored in a collection named after the first 63 characters
    of its SHA-256. When storing succeeds and no document with the sanitized
    file name exists yet, a document record is inserted for the requesting
    admin, with the folder-derived tags serialized into its content.

    Errors are logged per file and never abort the scan.

    Returns:
        bool: Always ``True``.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # Context manager so the handle is released even if hashing fails.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            # guess_type returns (type, encoding); only the type is needed.
            loader, _known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            try:
                result = store_data_in_vector_db(data, collection_name)

                if result:
                    sanitized_filename = sanitize_filename(filename)
                    doc = Documents.get_doc_by_name(sanitized_filename)

                    if doc is None:
                        Documents.insert_new_doc(
                            user.id,
                            DocumentForm(
                                **{
                                    "name": sanitized_filename,
                                    "title": filename,
                                    "collection_name": collection_name,
                                    "filename": filename,
                                    "content": (
                                        json.dumps(
                                            {
                                                "tags": [
                                                    {"name": name} for name in tags
                                                ]
                                            }
                                        )
                                        if len(tags)
                                        else "{}"
                                    ),
                                }
                            ),
                        )
            except Exception as e:
                # Best effort: a file that fails to store is logged and skipped.
                log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1286
@app.get("/reset/db")
1287
1288
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the Chroma client; the route depends on ``get_admin_user``."""
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
1289
1290


Timothy J. Baek's avatar
Timothy J. Baek committed
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    """Delete every file, link and sub-directory inside ``UPLOAD_DIR``.

    Failures are logged and never raised: a single entry that cannot be
    removed does not stop the remaining entries from being processed.

    Returns:
        bool: Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"
    try:
        # Check if the directory exists
        if os.path.exists(folder):
            # Iterate over all the files and directories in the specified directory
            for filename in os.listdir(folder):
                file_path = os.path.join(folder, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)  # Remove the file or link
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)  # Remove the directory
                except Exception as e:
                    # Report through the module logger, like the rest of the file.
                    log.error(f"Failed to delete {file_path}. Reason: {e}")
        else:
            log.warning(f"The directory {folder} does not exist")
    except Exception as e:
        log.error(f"Failed to process the directory {folder}. Reason: {e}")

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1315
@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Empty ``UPLOAD_DIR`` and reset the Chroma client.

    Both steps are best effort: failures are logged and never raised. A
    missing or unreadable upload directory does not prevent the vector DB
    reset from running.

    Returns:
        bool: Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"

    try:
        filenames = os.listdir(folder)
    except OSError as e:
        # Nothing to delete (e.g. the directory does not exist); still
        # carry on with the vector DB reset below.
        log.error(f"Failed to process the directory {folder}. Reason: {e}")
        filenames = []

    for filename in filenames:
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error(f"Failed to delete {file_path}. Reason: {e}")

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1334

Timothy J. Baek's avatar
Timothy J. Baek committed
1335

1336
1337
class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""

    def lazy_load(self) -> Iterator[Document]:
        """Yield one ``Document`` per URL in ``web_paths``, skipping failures.

        Each document carries the page text plus ``source`` metadata and,
        when the corresponding tags are present, ``title``, ``description``
        and ``language``. A URL that raises is logged and skipped.
        """
        for path in self.web_paths:
            try:
                soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
                text = soup.get_text(**self.bs_get_text_kwargs)

                # Build metadata
                metadata = {"source": path}

                title_tag = soup.find("title")
                if title_tag:
                    metadata["title"] = title_tag.get_text()

                description_tag = soup.find("meta", attrs={"name": "description"})
                if description_tag:
                    metadata["description"] = description_tag.get(
                        "content", "No description found."
                    )

                html_tag = soup.find("html")
                if html_tag:
                    metadata["language"] = html_tag.get("lang", "No language found.")

                yield Document(page_content=text, metadata=metadata)
            except Exception as e:
                # Log the error and continue with the next URL
                log.error(f"Error loading {path}: {e}")
Timothy J. Baek's avatar
Timothy J. Baek committed
1361
1362


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1363
1364
1365
1366
1367
1368
1369
1370
1371
# Debug endpoints for the embedding function; registered only when ENV is "dev".
# NOTE(review): neither route declares a user dependency — confirm this is
# intended for dev-only use.
if ENV == "dev":

    @app.get("/ef")
    async def get_embeddings():
        """Return the embedding function's result for the fixed text "hello world"."""
        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        """Return the embedding function's result for the ``text`` path parameter."""
        return {"result": app.state.EMBEDDING_FUNCTION(text)}