main.py 38.3 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
14
from typing import List, Union, Sequence
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
    YoutubeLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
33
)
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

36
37
38
39
40
import validators
import urllib.parse
import socket


41
42
from pydantic import BaseModel
from typing import Optional
43
import mimetypes
44
import uuid
45
46
import json

47
import sentence_transformers
48

Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
49
from apps.webui.models.documents import (
50
51
52
53
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
54

55
from apps.rag.utils import (
56
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
57
58
59
60
61
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
62
)
Timothy J. Baek's avatar
Timothy J. Baek committed
63

Timothy J. Baek's avatar
Timothy J. Baek committed
64
65
66
67
68
69
70
71
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack


72
73
74
75
76
77
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
78
from utils.utils import get_current_user, get_admin_user
79

80
from config import (
81
    AppConfig,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
82
    ENV,
83
    SRC_LOG_LEVELS,
84
85
    UPLOAD_DIR,
    DOCS_DIR,
86
87
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
88
    RAG_EMBEDDING_ENGINE,
89
    RAG_EMBEDDING_MODEL,
90
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
91
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
92
    ENABLE_RAG_HYBRID_SEARCH,
93
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
94
    RAG_RERANKING_MODEL,
95
    PDF_EXTRACT_IMAGES,
96
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
97
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
98
99
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
100
    DEVICE_TYPE,
101
102
103
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
104
    RAG_TEMPLATE,
105
    ENABLE_RAG_LOCAL_WEB_FETCH,
106
    YOUTUBE_LOADER_LANGUAGE,
Timothy J. Baek's avatar
Timothy J. Baek committed
107
    ENABLE_RAG_WEB_SEARCH,
Timothy J. Baek's avatar
Timothy J. Baek committed
108
    RAG_WEB_SEARCH_ENGINE,
Timothy J. Baek's avatar
Timothy J. Baek committed
109
110
111
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
Timothy J. Baek's avatar
Timothy J. Baek committed
112
    BRAVE_SEARCH_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
113
114
115
116
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
    RAG_WEB_SEARCH_RESULT_COUNT,
117
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
118
    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
119
)
120

121
122
from constants import ERROR_MESSAGES

123
124
125
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()
app.state.config = AppConfig()


def _seed_app_config(config, settings):
    """Assign every ``name -> value`` pair of *settings* onto *config*, in order.

    ``setattr`` dispatches to the config object's ``__setattr__`` exactly as a
    plain ``config.NAME = value`` statement does.
    """
    for attr_name, initial_value in settings.items():
        setattr(config, attr_name, initial_value)


# Seed the runtime-editable settings from the values imported from ``config``.
# The admin endpoints further down overwrite these on app.state.config.
_seed_app_config(
    app.state.config,
    {
        # Retrieval
        "TOP_K": RAG_TOP_K,
        "RELEVANCE_THRESHOLD": RAG_RELEVANCE_THRESHOLD,
        "ENABLE_RAG_HYBRID_SEARCH": ENABLE_RAG_HYBRID_SEARCH,
        "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        # Chunking
        "CHUNK_SIZE": CHUNK_SIZE,
        "CHUNK_OVERLAP": CHUNK_OVERLAP,
        # Embedding / reranking models and prompt template
        "RAG_EMBEDDING_ENGINE": RAG_EMBEDDING_ENGINE,
        "RAG_EMBEDDING_MODEL": RAG_EMBEDDING_MODEL,
        "RAG_EMBEDDING_OPENAI_BATCH_SIZE": RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        "RAG_RERANKING_MODEL": RAG_RERANKING_MODEL,
        "RAG_TEMPLATE": RAG_TEMPLATE,
        # OpenAI-compatible embedding endpoint
        "OPENAI_API_BASE_URL": RAG_OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": RAG_OPENAI_API_KEY,
        # Document loaders
        "PDF_EXTRACT_IMAGES": PDF_EXTRACT_IMAGES,
        "YOUTUBE_LOADER_LANGUAGE": YOUTUBE_LOADER_LANGUAGE,
        # Web search
        "ENABLE_RAG_WEB_SEARCH": ENABLE_RAG_WEB_SEARCH,
        "RAG_WEB_SEARCH_ENGINE": RAG_WEB_SEARCH_ENGINE,
        "SEARXNG_QUERY_URL": SEARXNG_QUERY_URL,
        "GOOGLE_PSE_API_KEY": GOOGLE_PSE_API_KEY,
        "GOOGLE_PSE_ENGINE_ID": GOOGLE_PSE_ENGINE_ID,
        "BRAVE_SEARCH_API_KEY": BRAVE_SEARCH_API_KEY,
        "SERPSTACK_API_KEY": SERPSTACK_API_KEY,
        "SERPSTACK_HTTPS": SERPSTACK_HTTPS,
        "SERPER_API_KEY": SERPER_API_KEY,
        "RAG_WEB_SEARCH_RESULT_COUNT": RAG_WEB_SEARCH_RESULT_COUNT,
        "RAG_WEB_SEARCH_CONCURRENT_REQUESTS": RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    },
)

# The YouTube translation target is held directly on app.state rather than on
# app.state.config; it starts out unset.
app.state.YOUTUBE_LOADER_TRANSLATION = None


172
173
174
175
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """(Re)load the local SentenceTransformer used to compute embeddings.

    A local model is loaded only when *embedding_model* is non-empty and the
    configured embedding engine is the empty string, i.e. not one of the
    remote engines such as "ollama"/"openai". In every other case
    ``app.state.sentence_transformer_ef`` is reset to ``None``.

    Args:
        embedding_model: Model identifier, resolved through ``get_model_path``.
        update_model: Forwarded to ``get_model_path``; presumably asks it to
            refresh/download the model — TODO confirm against apps.rag.utils.
    """
    wants_local_model = bool(embedding_model) and (
        app.state.config.RAG_EMBEDDING_ENGINE == ""
    )
    if not wants_local_model:
        app.state.sentence_transformer_ef = None
        return

    model_location = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_location,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """(Re)load the CrossEncoder used to rerank hybrid-search results.

    An empty/falsy *reranking_model* disables reranking by setting
    ``app.state.sentence_transformer_rf`` to ``None``.

    Args:
        reranking_model: Model identifier, resolved through ``get_model_path``.
        update_model: Forwarded to ``get_model_path``; presumably asks it to
            refresh/download the model — TODO confirm against apps.rag.utils.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_location = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_location,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the local embedding and reranking models named in the config at import
# time. The AUTO_UPDATE flags are forwarded as ``update_model``.
update_embedding_model(
    embedding_model=app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    reranking_model=app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Build the embedding callable shared by the query/ingest handlers. It depends
# on the local model loaded just above (which may be None for remote engines).
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
220
221
# Cross-origin requests are accepted from any origin, with any method and any
# header, and credentials are allowed.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)


Timothy J. Baek's avatar
Timothy J. Baek committed
232
class CollectionNameForm(BaseModel):
    # Target vector-store collection; "test" when the client omits it.
    collection_name: Union[str, None] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
236
class UrlForm(CollectionNameForm):
    # URL of the page/video to ingest (collection_name is inherited).
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
239

240
241
242
243
class SearchForm(CollectionNameForm):
    # Search query text (collection_name is inherited).
    query: str


Timothy J. Baek's avatar
Timothy J. Baek committed
244
245
@app.get("/")
async def get_status():
    """Report liveness together with the current chunking, template and model settings."""
    config = app.state.config
    return dict(
        status=True,
        chunk_size=config.CHUNK_SIZE,
        chunk_overlap=config.CHUNK_OVERLAP,
        template=config.RAG_TEMPLATE,
        embedding_engine=config.RAG_EMBEDDING_ENGINE,
        embedding_model=config.RAG_EMBEDDING_MODEL,
        reranking_model=config.RAG_RERANKING_MODEL,
        openai_batch_size=config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
258
259
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and OpenAI-compatible connection settings.

    Admin-only (``get_admin_user``). Note that the response includes the
    stored API key.
    """
    config = app.state.config
    openai_settings = dict(
        url=config.OPENAI_API_BASE_URL,
        key=config.OPENAI_API_KEY,
        batch_size=config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    )
    return dict(
        status=True,
        embedding_engine=config.RAG_EMBEDDING_ENGINE,
        embedding_model=config.RAG_EMBEDDING_MODEL,
        openai_config=openai_settings,
    )


Steven Kreitzer's avatar
Steven Kreitzer committed
272
273
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the configured reranking model name (admin only)."""
    current_model = app.state.config.RAG_RERANKING_MODEL
    return dict(status=True, reranking_model=current_model)
Steven Kreitzer's avatar
Steven Kreitzer committed
278
279


280
281
282
class OpenAIConfigForm(BaseModel):
    # Connection settings for the "ollama"/"openai" embedding engines.
    url: str  # API base URL
    key: str  # API key
    # update_embedding_config stores 1 when this is missing or 0.
    batch_size: Union[int, None] = None
284
285


286
class EmbeddingModelUpdateForm(BaseModel):
    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Union[OpenAIConfigForm, None] = None
    # "" selects the local SentenceTransformer path in update_embedding_model.
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
292
293
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function.

    Admin-only. For the "ollama" and "openai" engines, a supplied
    ``openai_config`` also replaces the stored API URL, key and batch size.
    A missing or zero batch size is stored as 1. The local SentenceTransformer
    is then reloaded through ``update_embedding_model``, and
    ``app.state.EMBEDDING_FUNCTION`` is rebuilt from the new settings.

    Returns:
        The engine, model and OpenAI-compatible settings now in effect.

    Raises:
        HTTPException: 500 when any step fails; the error is logged first.
    """
    config = app.state.config
    log.info(
        f"Updating embedding model: {config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        uses_remote_engine = config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
        openai_settings = form_data.openai_config
        if uses_remote_engine and openai_settings is not None:
            config.OPENAI_API_BASE_URL = openai_settings.url
            config.OPENAI_API_KEY = openai_settings.key
            config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = openai_settings.batch_size or 1

        update_embedding_model(config.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            config.RAG_EMBEDDING_ENGINE,
            config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            config.OPENAI_API_KEY,
            config.OPENAI_API_BASE_URL,
            config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        openai_config = {
            "url": config.OPENAI_API_BASE_URL,
            "key": config.OPENAI_API_KEY,
            "batch_size": config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        }
        return {
            "status": True,
            "embedding_engine": config.RAG_EMBEDDING_ENGINE,
            "embedding_model": config.RAG_EMBEDDING_MODEL,
            "openai_config": openai_config,
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
340
341


Steven Kreitzer's avatar
Steven Kreitzer committed
342
343
class RerankingModelUpdateForm(BaseModel):
    # New CrossEncoder model identifier; a falsy value disables reranking.
    reranking_model: str
344

Steven Kreitzer's avatar
Steven Kreitzer committed
345
346
347
348
349
350

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it (admin only).

    Stores the new model name in ``app.state.config`` and then reloads the
    CrossEncoder through ``update_reranking_model`` with ``update_model=True``.

    Returns:
        ``{"status": True, "reranking_model": <name now in effect>}``.

    Raises:
        HTTPException: 500 when storing or loading the model fails; the
            error is logged first.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # ``True`` is the update_model argument. It used to sit outside the
        # call parentheses (``update_reranking_model(...), True``), which
        # built and discarded a tuple, so the flag never reached the function.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
370
371
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the admin-editable RAG settings, grouped by area (admin only).

    The response includes the stored web-search API keys.
    """
    config = app.state.config

    chunk_settings = {
        "chunk_size": config.CHUNK_SIZE,
        "chunk_overlap": config.CHUNK_OVERLAP,
    }
    youtube_settings = {
        "language": config.YOUTUBE_LOADER_LANGUAGE,
        "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
    }
    search_settings = {
        "enabled": config.ENABLE_RAG_WEB_SEARCH,
        "engine": config.RAG_WEB_SEARCH_ENGINE,
        "searxng_query_url": config.SEARXNG_QUERY_URL,
        "google_pse_api_key": config.GOOGLE_PSE_API_KEY,
        "google_pse_engine_id": config.GOOGLE_PSE_ENGINE_ID,
        "brave_search_api_key": config.BRAVE_SEARCH_API_KEY,
        "serpstack_api_key": config.SERPSTACK_API_KEY,
        "serpstack_https": config.SERPSTACK_HTTPS,
        "serper_api_key": config.SERPER_API_KEY,
        "result_count": config.RAG_WEB_SEARCH_RESULT_COUNT,
        "concurrent_requests": config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    }
    web_settings = {
        "ssl_verification": config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "search": search_settings,
    }

    return {
        "status": True,
        "pdf_extract_images": config.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
        "youtube": youtube_settings,
        "web": web_settings,
    }


class ChunkParamUpdateForm(BaseModel):
    # Both values are required and replace CHUNK_SIZE / CHUNK_OVERLAP together.
    chunk_size: int
    chunk_overlap: int


407
408
409
410
411
class YoutubeLoaderConfig(BaseModel):
    # Transcript language preferences passed to YoutubeLoader.
    language: List[str]
    # Optional translation target; None leaves transcripts untranslated.
    translation: Union[str, None] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
412
class WebSearchConfig(BaseModel):
    # Master switch and the selected search backend.
    enabled: bool
    engine: Union[str, None] = None

    # Per-backend connection settings; only the selected engine's are used.
    searxng_query_url: Union[str, None] = None
    google_pse_api_key: Union[str, None] = None
    google_pse_engine_id: Union[str, None] = None
    brave_search_api_key: Union[str, None] = None
    serpstack_api_key: Union[str, None] = None
    serpstack_https: Union[bool, None] = None
    serper_api_key: Union[str, None] = None

    # Result volume and fetch parallelism.
    result_count: Union[int, None] = None
    concurrent_requests: Union[int, None] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
426
427
428
429
430
class WebConfig(BaseModel):
    # Web-search settings plus SSL verification for the web page loader.
    search: WebSearchConfig
    web_loader_ssl_verification: Union[bool, None] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
431
class ConfigUpdateForm(BaseModel):
    # Every section is optional; omitted sections leave their settings alone.
    pdf_extract_images: Union[bool, None] = None
    chunk: Union[ChunkParamUpdateForm, None] = None
    youtube: Union[YoutubeLoaderConfig, None] = None
    web: Union[WebConfig, None] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
436
437
438
439


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply a partial update to the admin-editable RAG settings (admin only).

    Each top-level section (``pdf_extract_images``, ``chunk``, ``youtube``,
    ``web``) that is omitted leaves its settings untouched. Inside ``web``,
    the search fields are copied over as sent. ``web_loader_ssl_verification``
    keeps its current value when sent as None, so omitting it cannot silently
    disable SSL verification.

    Returns:
        The full settings snapshot after the update, in the same shape as
        ``GET /config``.
    """
    config = app.state.config

    config.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else config.PDF_EXTRACT_IMAGES
    )

    if form_data.chunk is not None:
        config.CHUNK_SIZE = form_data.chunk.chunk_size
        config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        web = form_data.web
        # The field is Optional with default None. Assigning None would turn
        # verification off (falsy), so treat None as "unchanged", the same
        # way pdf_extract_images is handled above.
        config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
            web.web_loader_ssl_verification
            if web.web_loader_ssl_verification is not None
            else config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
        )

        search = web.search
        config.ENABLE_RAG_WEB_SEARCH = search.enabled
        config.RAG_WEB_SEARCH_ENGINE = search.engine
        config.SEARXNG_QUERY_URL = search.searxng_query_url
        config.GOOGLE_PSE_API_KEY = search.google_pse_api_key
        config.GOOGLE_PSE_ENGINE_ID = search.google_pse_engine_id
        config.BRAVE_SEARCH_API_KEY = search.brave_search_api_key
        config.SERPSTACK_API_KEY = search.serpstack_api_key
        config.SERPSTACK_HTTPS = search.serpstack_https
        config.SERPER_API_KEY = search.serper_api_key
        config.RAG_WEB_SEARCH_RESULT_COUNT = search.result_count
        config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = search.concurrent_requests

    return {
        "status": True,
        "pdf_extract_images": config.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": config.CHUNK_SIZE,
            "chunk_overlap": config.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": config.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": config.ENABLE_RAG_WEB_SEARCH,
                "engine": config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": config.SEARXNG_QUERY_URL,
                "google_pse_api_key": config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": config.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": config.SERPSTACK_API_KEY,
                "serpstack_https": config.SERPSTACK_HTTPS,
                "serper_api_key": config.SERPER_API_KEY,
                "result_count": config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
505
506


Timothy J. Baek's avatar
Timothy J. Baek committed
507
508
509
510
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template (available to any signed-in user)."""
    current_template = app.state.config.RAG_TEMPLATE
    return dict(status=True, template=current_template)


515
516
517
518
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the retrieval settings: template, top-k, relevance threshold, hybrid flag."""
    config = app.state.config
    return dict(
        status=True,
        template=config.RAG_TEMPLATE,
        k=config.TOP_K,
        r=config.RELEVANCE_THRESHOLD,
        hybrid=config.ENABLE_RAG_HYBRID_SEARCH,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
524
525


526
527
class QuerySettingsForm(BaseModel):
    # Falsy/omitted values reset to defaults in update_query_settings.
    k: Union[int, None] = None  # top-k results
    r: Union[float, None] = None  # relevance threshold
    template: Union[str, None] = None
    hybrid: Union[bool, None] = None
531
532
533
534
535
536


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the retrieval settings (admin only).

    Any field that is omitted or falsy (None, "", 0, False) is reset to its
    default: the imported ``RAG_TEMPLATE``, k=4, r=0.0, hybrid disabled.

    Returns:
        The retrieval settings now in effect.
    """
    config = app.state.config

    config.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    config.TOP_K = form_data.k or 4
    config.RELEVANCE_THRESHOLD = form_data.r or 0.0
    config.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return dict(
        status=True,
        template=config.RAG_TEMPLATE,
        k=config.TOP_K,
        r=config.RELEVANCE_THRESHOLD,
        hybrid=config.ENABLE_RAG_HYBRID_SEARCH,
    )
552
553


554
class QueryDocForm(BaseModel):
    # Single collection to search and the query text.
    collection_name: str
    query: str
    # Per-request overrides; falsy values fall back to the global settings.
    k: Union[int, None] = None
    r: Union[float, None] = None
    # NOTE(review): not read by query_doc_handler, which uses the global flag.
    hybrid: Union[bool, None] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
560
561


562
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query one collection, using hybrid search when it is globally enabled.

    ``k`` (and, for hybrid search, ``r``) fall back to the configured TOP_K /
    RELEVANCE_THRESHOLD when missing or falsy. ``form_data.hybrid`` is not
    consulted; the global ENABLE_RAG_HYBRID_SEARCH flag decides the mode.

    Raises:
        HTTPException: 400 wrapping any error raised during the query.
    """
    try:
        config = app.state.config
        top_k = form_data.k or config.TOP_K

        if not config.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        return query_doc_with_hybrid_search(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or config.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
592
593


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
594
595
596
class QueryCollectionsForm(BaseModel):
    # Collections to search together and the query text.
    collection_names: List[str]
    query: str
    # Per-request overrides; falsy values fall back to the global settings.
    k: Union[int, None] = None
    r: Union[float, None] = None
    # NOTE(review): not read by query_collection_handler (global flag wins).
    hybrid: Union[bool, None] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
600
601


602
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections at once, using hybrid search when globally enabled.

    ``k`` (and, for hybrid search, ``r``) fall back to the configured TOP_K /
    RELEVANCE_THRESHOLD when missing or falsy. ``form_data.hybrid`` is not
    consulted; the global ENABLE_RAG_HYBRID_SEARCH flag decides the mode.

    Raises:
        HTTPException: 400 wrapping any error raised during the query.
    """
    try:
        config = app.state.config
        top_k = form_data.k or config.TOP_K

        if not config.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        return query_collection_with_hybrid_search(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or config.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
633
634


Timothy J. Baek's avatar
Timothy J. Baek committed
635
636
637
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video (with video info) and store it in the vector DB.

    The loader uses the configured transcript language and the optional
    translation target. When no collection name is supplied (empty string or
    null), the name is the first 63 characters of
    ``calculate_sha256_string(url)``. The data is stored with
    ``overwrite=True``.

    Returns:
        ``{"status": True, "collection_name": ..., "filename": <the URL>}``.

    Raises:
        HTTPException: 400 wrapping any error raised while loading or storing.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # collection_name is Optional, so an explicit null must also fall back
        # to the URL-derived name. Previously only "" did, and None was passed
        # through. The 63-char cut presumably fits a vector-store name limit.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


664
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and store its content in the vector DB.

    Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    The page is loaded through ``get_web_loader`` using the configured SSL
    verification setting. When no collection name is supplied (empty string
    or null), the name is the first 63 characters of
    ``calculate_sha256_string(url)``. The data is stored with
    ``overwrite=True``.

    Returns:
        ``{"status": True, "collection_name": ..., "filename": <the URL>}``.

    Raises:
        HTTPException: 400 wrapping any error, including an invalid URL.
    """
    try:
        loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # collection_name is Optional, so an explicit null must also fall back
        # to the URL-derived name. Previously only "" did, and None was passed
        # through.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

691

692
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a ``WebBaseLoader`` for one URL or a sequence of URLs.

    ``validate_url`` itself raises ``ValueError`` for malformed or blocked URLs.
    A falsy result (input that is neither a string nor a sequence) is turned
    into the same ``ValueError(ERROR_MESSAGES.INVALID_URL)`` here. Failed
    fetches are skipped by the loader (``continue_on_failure=True``).
    """
    if validate_url(url):
        return WebBaseLoader(
            url,
            verify_ssl=verify_ssl,
            requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            continue_on_failure=True,
        )
    raise ValueError(ERROR_MESSAGES.INVALID_URL)
702
703


704
705
706
707
def _is_non_public_ip(ip: str) -> bool:
    """Return True if *ip* is not a publicly routable IPv4/IPv6 address."""
    import ipaddress

    # getaddrinfo may attach an IPv6 zone id ("fe80::1%eth0"); drop it.
    addr = ipaddress.ip_address(ip.split("%", 1)[0])
    # Judge IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) by their IPv4 part.
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        addr = mapped
    return (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_reserved
        or addr.is_multicast
        or addr.is_unspecified
    )


def validate_url(url: Union[str, Sequence[str]]):
    """Check that *url*, or every URL in a sequence, may be fetched.

    Returns:
        True when every URL passes (an empty sequence passes); False when
        *url* is neither a string nor a sequence.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` when a URL is malformed or,
            with ``ENABLE_RAG_LOCAL_WEB_FETCH`` disabled, when its host
            resolves to a non-public IPv4 or IPv6 address.
    """
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled: refuse URLs whose host resolves to a
            # private/loopback/link-local address.
            # This is still vulnerable to DNS rebinding, since WebBaseLoader
            # resolves the host again when it actually fetches.
            parsed_url = urllib.parse.urlparse(url)
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            # Classify with the stdlib for both families: validators.ipv6()
            # accepts no ``private`` argument (only ipv4() does), so an IPv6
            # check built on it never fires.
            for ip in [*ipv4_addresses, *ipv6_addresses]:
                if _is_non_public_ip(ip):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False


728
729
730
731
732
733
734
735
736
737
738
def resolve_hostname(hostname):
    """Resolve *hostname* and split the results by address family.

    Returns:
        ``(ipv4_addresses, ipv6_addresses)``: two lists of address strings in
        ``getaddrinfo`` order, duplicates included (one entry per returned
        socket type).

    Raises:
        socket.gaierror: if resolution fails.
    """
    v4, v6 = [], []
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            v4.append(sockaddr[0])
        elif family == socket.AF_INET6:
            v6.append(sockaddr[0])
    return v4, v6


Timothy J. Baek's avatar
Timothy J. Baek committed
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Search the web with the selected engine and return its results.

    The engine's credentials/endpoint are read from ``app.state.config``;
    every engine is asked for ``RAG_WEB_SEARCH_RESULT_COUNT`` results.
    Supported engines and the settings they need:

    - "searxng": SEARXNG_QUERY_URL
    - "google_pse": GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - "brave": BRAVE_SEARCH_API_KEY
    - "serpstack": SERPSTACK_API_KEY (plus SERPSTACK_HTTPS)
    - "serper": SERPER_API_KEY

    Args:
        engine (str): Name of the search engine to use.
        query (str): The query to search for.

    Raises:
        Exception: if the engine's required setting is empty, or the engine
            name is not one of the above.
    """
    # TODO: add playwright to search the web
    if engine == "searxng":
        cfg = app.state.config
        if not cfg.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(
            cfg.SEARXNG_QUERY_URL,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "google_pse":
        cfg = app.state.config
        if not (cfg.GOOGLE_PSE_API_KEY and cfg.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            cfg.GOOGLE_PSE_API_KEY,
            cfg.GOOGLE_PSE_ENGINE_ID,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "brave":
        cfg = app.state.config
        if not cfg.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(
            cfg.BRAVE_SEARCH_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "serpstack":
        cfg = app.state.config
        if not cfg.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            cfg.SERPSTACK_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
            https_enabled=cfg.SERPSTACK_HTTPS,
        )

    if engine == "serper":
        cfg = app.state.config
        if not cfg.SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(
            cfg.SERPER_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    raise Exception("No search engine API key found in environment variables")


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
809
810
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Run a web search and index the pages it returns.

    The configured ``RAG_WEB_SEARCH_ENGINE`` is queried with
    ``form_data.query``; every result link is fetched and stored with
    ``overwrite=True``. An empty ``collection_name`` is replaced by the first
    63 characters of ``calculate_sha256_string(query)``.

    Returns:
        ``{"status": True, "collection_name": ..., "filenames": [<urls>]}``.

    Raises:
        HTTPException: 400 with ``WEB_SEARCH_ERROR`` if the search itself
            fails; 400 with ``DEFAULT`` if loading or storing the pages fails.
    """
    try:
        web_results = search_web(
            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        urls = [result.link for result in web_results]
        # Honour the same SSL-verification setting as the single-URL /web
        # endpoint instead of always verifying.
        loader = get_web_loader(
            urls,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


847
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: documents as returned by a LangChain loader's ``load()``.
        collection_name: name of the target collection.
        overwrite: delete an existing collection of the same name first.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` when splitting yields no
            chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    # Return the bare bool, as annotated. A tuple such as (result, None) is
    # always truthy and would hide a failed store from callers that test
    # ``if result:``.
    return store_docs_in_vector_db(docs, collection_name, overwrite)
862
863
864


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk a raw string and store the chunks in a collection.

    The single *text* is split with the configured chunk size/overlap. The
    *metadata* dict is attached to every chunk, and ``add_start_index=True``
    also records each chunk's start offset. Returns whatever
    ``store_docs_in_vector_db`` returns.
    """
    cfg = app.state.config
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=cfg.CHUNK_SIZE,
        chunk_overlap=cfg.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
876
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed documents and add them to a newly created Chroma collection.

    Args:
        docs: documents exposing ``page_content`` and ``metadata``.
        collection_name: name of the collection to create.
        overwrite: when True, a collection with this name is deleted before
            being recreated.

    Returns:
        True on success. If anything fails, the error is logged and the
        result is True only when the exception class is named
        ``UniqueConstraintError``; otherwise False.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [d.page_content for d in docs]
    metadatas = [d.metadata for d in docs]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if existing.name == collection_name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embed = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        # Newlines are flattened only in the text sent for embedding; the
        # stored documents keep their original text.
        embeddings = embed([t.replace("\n", " ") for t in texts])

        ids = [str(uuid.uuid4()) for _ in texts]
        # create_batches splits the payload into batches sized for CHROMA_CLIENT.
        for batch in create_batches(
            api=CHROMA_CLIENT,
            ids=ids,
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        ):
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # A UniqueConstraintError (presumably: the collection already exists)
        # is treated as "already stored".
        return e.__class__.__name__ == "UniqueConstraintError"


921
922
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a LangChain document loader for a file.

    Selection uses the lower-cased text after the last "." in *filename* (the
    whole name when there is no dot) and/or *file_content_type*. Checks run in
    a fixed priority order; the first match wins.

    Args:
        filename: original file name; only its extension is inspected.
        file_content_type: MIME type reported for the file; may be None.
        file_path: path on disk handed to the loader.

    Returns:
        ``(loader, known_type)``. ``known_type`` is False only when nothing
        matched and the file falls back to a plain-text loader.
    """
    ext = filename.split(".")[-1].lower()

    # Source-code / config extensions that are read as plain text.
    source_exts = {
        "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
        "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", "r",
        "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", "nginxconf",
        "conf", "m", "mm", "plsql", "perl", "rb", "rs", "db2", "scala", "bash",
        "swift", "vue", "svelte",
    }
    word_mime = (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    )
    excel_mimes = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
    powerpoint_mimes = (
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )

    if ext == "pdf":
        pdf_loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
        return pdf_loader, True
    if ext == "csv":
        return CSVLoader(file_path), True
    if ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if ext in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    # NOTE(review): legacy binary ".doc" files are routed here too, but
    # Docx2txtLoader presumably only understands the .docx format. Confirm.
    if file_content_type == word_mime or ext in ("doc", "docx"):
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_mimes or ext in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True
    if file_content_type in powerpoint_mimes or ext in ("ppt", "pptx"):
        return UnstructuredPowerPointLoader(file_path), True
    if ext in source_exts or (file_content_type and "text/" in file_content_type):
        return TextLoader(file_path, autodetect_encoding=True), True

    # Nothing matched: read as plain text anyway, but report it as a guess.
    return TextLoader(file_path, autodetect_encoding=True), False


1013
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under UPLOAD_DIR and index its contents.

    Args:
        collection_name: target collection. When omitted, the first 63
            characters of ``calculate_sha256`` over the saved file are used.
        file: the uploaded file. Only its basename is used on disk.
        user: authenticated user (required by the dependency; not otherwise
            used here).

    Returns:
        ``{"status", "collection_name", "filename", "known_type"}``, where
        ``known_type`` is False when the file fell back to the plain-text
        loader.

    Raises:
        HTTPException: 500 when storing the chunks reports failure; 400 for
            any other error, with a dedicated message when pandoc is missing.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        # Keep only the basename so the upload cannot escape UPLOAD_DIR.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        if not store_data_in_vector_db(data, collection_name):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except HTTPException:
        # Already a well-formed HTTP error; don't re-wrap it as a 400.
        raise
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
1068
1069


1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
class TextRAGForm(BaseModel):
    # Request body for POST /text (handled by store_text).
    name: str  # recorded as the "name" metadata on each stored chunk
    content: str  # raw text that is chunked and indexed
    # Optional target collection; store_text derives one from a hash of
    # `content` when this is None.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text snippet, tagging each chunk with its name and creator.

    When ``form_data.collection_name`` is None, the collection name is
    derived from a hash of the content.

    Returns:
        ``{"status": True, "collection_name": ...}``.

    Raises:
        HTTPException: 500 when storing the chunks fails.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncate to 63 characters like every other hashed collection name
        # in this module; Chroma rejects collection names longer than that.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


1101
1102
def _scanned_doc_content(tags) -> str:
    """Serialize folder tags into the JSON string stored as a document's content."""
    if not tags:
        return "{}"
    return json.dumps({"tags": [{"name": name} for name in tags]})


@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under DOCS_DIR and register it as a document.

    For each file:
    - the collection name is the first 63 characters of ``calculate_sha256``
      over its bytes;
    - it is loaded with ``get_loader`` (MIME type guessed from the path) and
      stored;
    - when storing succeeds and no document with the sanitized filename
      exists yet, one is inserted on behalf of *user*, with tags from
      ``extract_folders_after_data_docs(path)`` (presumably the sub-folder
      names; confirm).

    Per-file errors are logged and the scan continues.

    Returns:
        True once the directory has been walked.
    """
    # rglob already recurses; "*" matches every entry at any depth.
    for path in Path(DOCS_DIR).rglob("*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type, _encoding = mimetypes.guess_type(path)

            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(filename, file_content_type, str(path))
            data = loader.load()

            if not store_data_in_vector_db(data, collection_name):
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is None:
                Documents.insert_new_doc(
                    user.id,
                    DocumentForm(
                        **{
                            "name": sanitized_filename,
                            "title": filename,
                            "collection_name": collection_name,
                            "filename": filename,
                            "content": _scanned_doc_content(tags),
                        }
                    ),
                )
        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1162
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    # Reset the vector store via CHROMA_CLIENT.reset(); access is gated by the
    # get_admin_user dependency. Returns None.
    # NOTE(review): destructive action exposed via GET. Chroma's reset() also
    # appears to require the client to be created with allow_reset=True;
    # confirm.
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
1165
1166
1167


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything in UPLOAD_DIR and reset the vector DB.

    Each directory entry is removed independently: files and symlinks are
    unlinked, directories are removed recursively. A failure on one entry is
    logged and skipped. A failure in ``CHROMA_CLIENT.reset()`` is also logged
    rather than raised, so the endpoint always returns True.
    """
    upload_root = str(UPLOAD_DIR)
    for entry in os.listdir(upload_root):
        entry_path = os.path.join(upload_root, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
                continue
            if os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s", entry_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196


if ENV == "dev":

    def _embed_response(text: str):
        # Shared body of the two dev-only debug endpoints defined next.
        return {"result": app.state.EMBEDDING_FUNCTION(text)}

    @app.get("/ef")
    async def get_embeddings():
        """Dev-only: return the embedding of the fixed string "hello world"."""
        return _embed_response("hello world")

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        """Dev-only: return the embedding of the text given in the URL path."""
        return _embed_response(text)