main.py 40.1 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
14
from typing import List, Union, Sequence
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
    YoutubeLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
33
)
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

36
37
38
39
40
import validators
import urllib.parse
import socket


41
42
from pydantic import BaseModel
from typing import Optional
43
import mimetypes
44
import uuid
45
46
import json

47
import sentence_transformers
48

Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
49
from apps.webui.models.documents import (
50
51
52
53
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
54

55
from apps.rag.utils import (
56
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
57
58
59
60
61
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
62
)
Timothy J. Baek's avatar
Timothy J. Baek committed
63

Timothy J. Baek's avatar
Timothy J. Baek committed
64
65
66
67
68
69
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
70
from apps.rag.search.serply import search_serply
Timothy J. Baek's avatar
Timothy J. Baek committed
71

72
73
74
75
76
77
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
78
from utils.utils import get_current_user, get_admin_user
79

80
from config import (
81
    AppConfig,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
82
    ENV,
83
    SRC_LOG_LEVELS,
84
85
    UPLOAD_DIR,
    DOCS_DIR,
86
87
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
88
    RAG_EMBEDDING_ENGINE,
89
    RAG_EMBEDDING_MODEL,
90
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
91
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
92
    ENABLE_RAG_HYBRID_SEARCH,
93
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
94
    RAG_RERANKING_MODEL,
95
    PDF_EXTRACT_IMAGES,
96
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
97
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
98
99
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
100
    DEVICE_TYPE,
101
102
103
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
104
    RAG_TEMPLATE,
105
    ENABLE_RAG_LOCAL_WEB_FETCH,
106
    YOUTUBE_LOADER_LANGUAGE,
Timothy J. Baek's avatar
Timothy J. Baek committed
107
    ENABLE_RAG_WEB_SEARCH,
Timothy J. Baek's avatar
Timothy J. Baek committed
108
    RAG_WEB_SEARCH_ENGINE,
Timothy J. Baek's avatar
Timothy J. Baek committed
109
110
111
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
Timothy J. Baek's avatar
Timothy J. Baek committed
112
    BRAVE_SEARCH_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
113
114
115
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
116
    SERPLY_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
117
    RAG_WEB_SEARCH_RESULT_COUNT,
118
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
119
    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
120
)
121

122
123
from constants import ERROR_MESSAGES

124
125
126
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

# FastAPI application for the RAG endpoints. Runtime-tunable settings live on
# app.state.config and are seeded below from the values imported from config.
app = FastAPI()
app.state.config = AppConfig()

# --- Retrieval ---
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH

# --- Web loader ---
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION

# --- Chunking ---
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# --- Embedding / reranking models and prompt template ---
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# --- OpenAI-compatible embedding endpoint ---
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# --- Document loaders ---
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
# Stored directly on app.state (not on app.state.config), unlike its siblings.
app.state.YOUTUBE_LOADER_TRANSLATION = None

# --- Web search ---
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS


174
175
176
177
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load (or unload) the local sentence-transformers embedding model.

    The model is loaded only when ``embedding_model`` is non-empty and the
    configured embedding engine is the empty string; in every other case
    ``app.state.sentence_transformer_ef`` is reset to ``None``.

    Args:
        embedding_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` as its second argument.
    """
    # Guard clause: nothing to load locally for an empty model name or a
    # non-default engine.
    if not embedding_model or app.state.config.RAG_EMBEDDING_ENGINE != "":
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load (or unload) the cross-encoder used for reranking.

    A falsy ``reranking_model`` clears ``app.state.sentence_transformer_rf``;
    otherwise a ``sentence_transformers.CrossEncoder`` is built from the path
    returned by ``get_model_path``.

    Args:
        reranking_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` as its second argument.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the configured models once at import time.
update_embedding_model(
    embedding_model=app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    reranking_model=app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Embedding callable shared by the query and storage handlers.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)

origins = ["*"]

# NOTE(review): every origin is allowed together with credentials, methods
# and headers. Confirm this permissive CORS policy is intended for deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
234
# Base request body carrying the target collection name.
class CollectionNameForm(BaseModel):
    # Defaults to "test" when the client omits the field.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
238
# Request body for the URL-based ingestion endpoints (/youtube and /web).
class UrlForm(CollectionNameForm):
    # Address of the resource to load.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
241

242
243
244
245
# Request body carrying a search query plus the inherited collection name.
class SearchForm(CollectionNameForm):
    query: str


Timothy J. Baek's avatar
Timothy J. Baek committed
246
247
@app.get("/")
async def get_status():
    # Reports the active chunking, template and model settings. No user
    # dependency is declared on this route.
    cfg = app.state.config
    summary = {
        "status": True,
        "chunk_size": cfg.CHUNK_SIZE,
        "chunk_overlap": cfg.CHUNK_OVERLAP,
        "template": cfg.RAG_TEMPLATE,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "reranking_model": cfg.RAG_RERANKING_MODEL,
        "openai_batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }
    return summary


Timothy J. Baek's avatar
Timothy J. Baek committed
260
261
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    # Admin-only view of the embedding engine/model and the OpenAI-compatible
    # endpoint settings (the API key is returned as stored).
    cfg = app.state.config
    openai_settings = {
        "url": cfg.OPENAI_API_BASE_URL,
        "key": cfg.OPENAI_API_KEY,
        "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }
    return {
        "status": True,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "openai_config": openai_settings,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
274
275
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    # Admin-only view of the configured reranking model.
    # (The misspelled function name is kept: FastAPI derives route metadata
    # from it.)
    current_model = app.state.config.RAG_RERANKING_MODEL
    return {"status": True, "reranking_model": current_model}
Steven Kreitzer's avatar
Steven Kreitzer committed
280
281


282
283
284
# OpenAI-compatible endpoint settings submitted with an embedding update.
class OpenAIConfigForm(BaseModel):
    url: str  # stored as OPENAI_API_BASE_URL
    key: str  # stored as OPENAI_API_KEY
    # A missing or falsy value is replaced by 1 in update_embedding_config.
    batch_size: Optional[int] = None
286
287


288
# Request body for POST /embedding/update.
class EmbeddingModelUpdateForm(BaseModel):
    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
294
295
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    # Admin-only: switch the embedding engine/model, reload the local model
    # and rebuild the shared embedding function. Any failure is logged and
    # reported as HTTP 500.
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    cfg = app.state.config
    try:
        cfg.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        cfg.RAG_EMBEDDING_MODEL = form_data.embedding_model

        # Remote engines may carry endpoint credentials and a batch size.
        openai_form = form_data.openai_config
        if cfg.RAG_EMBEDDING_ENGINE in ["ollama", "openai"] and openai_form is not None:
            cfg.OPENAI_API_BASE_URL = openai_form.url
            cfg.OPENAI_API_KEY = openai_form.key
            # Missing or zero batch size falls back to 1.
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE = openai_form.batch_size or 1

        update_embedding_model(cfg.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            cfg.RAG_EMBEDDING_ENGINE,
            cfg.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            cfg.OPENAI_API_KEY,
            cfg.OPENAI_API_BASE_URL,
            cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        openai_settings = {
            "url": cfg.OPENAI_API_BASE_URL,
            "key": cfg.OPENAI_API_KEY,
            "batch_size": cfg.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        }
        return {
            "status": True,
            "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
            "embedding_model": cfg.RAG_EMBEDDING_MODEL,
            "openai_config": openai_settings,
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
342
343


Steven Kreitzer's avatar
Steven Kreitzer committed
344
345
# Request body for POST /reranking/update.
class RerankingModelUpdateForm(BaseModel):
    reranking_model: str
346

Steven Kreitzer's avatar
Steven Kreitzer committed
347
348
349
350
351
352

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    # Admin-only: store the new reranking model name and (re)load the model.
    # Returns the resulting configuration; any failure is logged and reported
    # as HTTP 500.
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # Pass True as update_model. The previous code wrote
        # `update_reranking_model(...), True`, which built and discarded a
        # tuple, so the flag never reached the function.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e


Timothy J. Baek's avatar
Timothy J. Baek committed
372
373
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    # Admin-only snapshot of loader, chunking and web-search settings
    # (API keys are returned as stored).
    cfg = app.state.config

    search_settings = {
        "enabled": cfg.ENABLE_RAG_WEB_SEARCH,
        "engine": cfg.RAG_WEB_SEARCH_ENGINE,
        "searxng_query_url": cfg.SEARXNG_QUERY_URL,
        "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
        "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
        "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
        "serpstack_api_key": cfg.SERPSTACK_API_KEY,
        "serpstack_https": cfg.SERPSTACK_HTTPS,
        "serper_api_key": cfg.SERPER_API_KEY,
        "serply_api_key": cfg.SERPLY_API_KEY,
        "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    }

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            # Translation lives on app.state, not on app.state.config.
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": search_settings,
        },
    }


# Chunking parameters accepted by POST /config/update; both values are
# copied verbatim onto app.state.config.
class ChunkParamUpdateForm(BaseModel):
    chunk_size: int
    chunk_overlap: int


410
411
412
413
414
# YouTube loader settings accepted by POST /config/update; the stored values
# are later passed as `language` / `translation` to the YouTube loader.
class YoutubeLoaderConfig(BaseModel):
    language: List[str]
    translation: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
415
# Web-search settings accepted by POST /config/update. Only `enabled` is
# required; every other field defaults to None when omitted.
class WebSearchConfig(BaseModel):
    enabled: bool
    engine: Optional[str] = None

    # Per-provider endpoints and credentials.
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    serply_api_key: Optional[str] = None

    # Limits applied to a search run.
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
430
431
432
433
434
# Web section of the config update payload.
# NOTE(review): the request field is `web_loader_ssl_verification`, while the
# config responses expose the same setting as "ssl_verification".
class WebConfig(BaseModel):
    search: WebSearchConfig
    web_loader_ssl_verification: Optional[bool] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
435
# Request body for POST /config/update. Every section is optional; a section
# left as None is skipped by the handler.
class ConfigUpdateForm(BaseModel):
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
440
441
442
443


def _keep_current_if_none(new_value, current_value):
    # Return new_value unless it is None, in which case keep current_value.
    return current_value if new_value is None else new_value


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    # Admin-only: apply the submitted sections of the RAG configuration and
    # return the resulting settings in the same shape as GET /config.
    #
    # Sections left as None are skipped. Inside the web section, boolean and
    # integer settings that are omitted (None) keep their current value, the
    # same way pdf_extract_images does. Previously an omitted field
    # overwrote the live setting with None. String settings (engine, URLs,
    # API keys) are still assigned as submitted, so None clears them.
    cfg = app.state.config

    cfg.PDF_EXTRACT_IMAGES = _keep_current_if_none(
        form_data.pdf_extract_images, cfg.PDF_EXTRACT_IMAGES
    )

    if form_data.chunk is not None:
        cfg.CHUNK_SIZE = form_data.chunk.chunk_size
        cfg.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        cfg.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        # Translation lives on app.state, not on app.state.config.
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        search = form_data.web.search

        cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = _keep_current_if_none(
            form_data.web.web_loader_ssl_verification,
            cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )

        cfg.ENABLE_RAG_WEB_SEARCH = search.enabled
        cfg.RAG_WEB_SEARCH_ENGINE = search.engine
        cfg.SEARXNG_QUERY_URL = search.searxng_query_url
        cfg.GOOGLE_PSE_API_KEY = search.google_pse_api_key
        cfg.GOOGLE_PSE_ENGINE_ID = search.google_pse_engine_id
        cfg.BRAVE_SEARCH_API_KEY = search.brave_search_api_key
        cfg.SERPSTACK_API_KEY = search.serpstack_api_key
        cfg.SERPSTACK_HTTPS = _keep_current_if_none(
            search.serpstack_https, cfg.SERPSTACK_HTTPS
        )
        cfg.SERPER_API_KEY = search.serper_api_key
        cfg.SERPLY_API_KEY = search.serply_api_key
        cfg.RAG_WEB_SEARCH_RESULT_COUNT = _keep_current_if_none(
            search.result_count, cfg.RAG_WEB_SEARCH_RESULT_COUNT
        )
        cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = _keep_current_if_none(
            search.concurrent_requests, cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS
        )

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": cfg.ENABLE_RAG_WEB_SEARCH,
                "engine": cfg.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": cfg.SEARXNG_QUERY_URL,
                "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": cfg.SERPSTACK_API_KEY,
                "serpstack_https": cfg.SERPSTACK_HTTPS,
                "serper_api_key": cfg.SERPER_API_KEY,
                "serply_api_key": cfg.SERPLY_API_KEY,
                "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
511
512


Timothy J. Baek's avatar
Timothy J. Baek committed
513
514
515
516
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    # Returns the active RAG prompt template to any authenticated user.
    active_template = app.state.config.RAG_TEMPLATE
    return {"status": True, "template": active_template}


521
522
523
524
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    # Admin-only view of the query settings: template, top-k ("k"),
    # relevance threshold ("r") and the hybrid-search flag.
    cfg = app.state.config
    settings = {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
    return settings
Timothy J. Baek's avatar
Timothy J. Baek committed
530
531


532
533
# Request body for POST /query/settings/update. Missing or falsy fields are
# replaced with fixed defaults by the handler.
class QuerySettingsForm(BaseModel):
    k: Optional[int] = None  # top-k
    r: Optional[float] = None  # relevance threshold
    template: Optional[str] = None
    hybrid: Optional[bool] = None
537
538
539
540
541
542


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    # Admin-only: overwrite the query settings. A missing or falsy field
    # resets to a fixed default: the imported RAG_TEMPLATE, k=4, r=0.0 and
    # hybrid=False.
    cfg = app.state.config

    cfg.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    cfg.TOP_K = form_data.k or 4
    cfg.RELEVANCE_THRESHOLD = form_data.r or 0.0
    cfg.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
558
559


560
# Request body for POST /query/doc.
class QueryDocForm(BaseModel):
    collection_name: str
    query: str
    # Optional per-request overrides; falsy values fall back to the app config.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): the /query/doc handler reads the app-level hybrid flag
    # and does not consult this field.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
566
567


568
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    # Query a single collection. Uses hybrid search (with reranking and a
    # relevance threshold) when enabled in the app config, plain vector
    # search otherwise. Any failure is logged and reported as HTTP 400.
    try:
        cfg = app.state.config

        if not cfg.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k or cfg.TOP_K,
            )

        return query_doc_with_hybrid_search(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=form_data.k or cfg.TOP_K,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or cfg.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
598
599


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
600
601
602
# Request body for POST /query/collection.
class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str
    # Optional per-request overrides; falsy values fall back to the app config.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): the /query/collection handler reads the app-level hybrid
    # flag and does not consult this field.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
606
607


608
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    # Query several collections at once. Uses hybrid search (with reranking
    # and a relevance threshold) when enabled in the app config, plain vector
    # search otherwise. Any failure is logged and reported as HTTP 400.
    try:
        cfg = app.state.config

        if not cfg.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k or cfg.TOP_K,
            )

        return query_collection_with_hybrid_search(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=form_data.k or cfg.TOP_K,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or cfg.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
639
640


Timothy J. Baek's avatar
Timothy J. Baek committed
641
642
643
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    # Load a YouTube video through YoutubeLoader and store the result in the
    # vector DB, overwriting any existing collection of the same name.
    # Returns the collection name used; any failure is logged and reported
    # as HTTP 400.
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        # Derive a name from the URL hash (first 63 characters) when none is
        # given. `not` also covers an explicit null, which the
        # Optional[str] field allows and the previous `== ""` check missed.
        collection_name = form_data.collection_name
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e


670
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and index its content in the vector store.

    Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    The page is loaded via ``get_web_loader`` (which validates the URL) with
    the configured SSL-verification setting. Documents are stored under
    ``form_data.collection_name``; when that is an empty string, the first
    63 characters of the SHA-256 of the URL are used instead. An existing
    collection with the same name is overwritten.

    Returns:
        dict: ``status``, ``collection_name`` and ``filename`` (the URL).

    Raises:
        HTTPException: 400 with the default error message on any failure.
    """
    try:
        loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

697

698
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a ``WebBaseLoader`` for one URL or a sequence of URLs.

    Args:
        url: A single URL or a sequence of URLs to load.
        verify_ssl: Forwarded to ``WebBaseLoader``.

    Returns:
        A ``WebBaseLoader`` that continues past individual page failures and
        is rate-limited by ``RAG_WEB_SEARCH_CONCURRENT_REQUESTS``.

    Raises:
        ValueError: If ``validate_url`` rejects the input (it either raises
            itself or returns a falsy value).
    """
    # Check if the URL is valid
    if not validate_url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    return WebBaseLoader(
        url,
        verify_ssl=verify_ssl,
        requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
        continue_on_failure=True,
    )
708
709


710
711
712
713
def validate_url(url: Union[str, Sequence[str]]):
    """Validate a URL, or every URL in a sequence, before it is fetched.

    For a string, the URL must pass ``validators.url``. When
    ``ENABLE_RAG_LOCAL_WEB_FETCH`` is false, the hostname is additionally
    resolved and the URL is rejected if any resolved IPv4/IPv6 address is
    private.

    Args:
        url: A single URL string or a sequence of URL strings.

    Returns:
        ``True`` for a valid string; for a sequence, whether every element
        validates (``True`` for an empty sequence); ``False`` for any other
        input type.

    Raises:
        ValueError: If a string URL is malformed or resolves to a private
            address while local fetching is disabled.
    """
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
            parsed_url = urllib.parse.urlparse(url)
            # Get IPv4 and IPv6 addresses
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            # Check if any of the resolved addresses are private
            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
            for ip in ipv4_addresses:
                if validators.ipv4(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
            for ip in ipv6_addresses:
                if validators.ipv6(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        # Validate each element; a raise from any element propagates.
        return all(validate_url(u) for u in url)
    else:
        return False


734
735
736
737
738
739
740
741
742
743
744
def resolve_hostname(hostname):
    """Resolve *hostname* and split the results by address family.

    Returns:
        A ``(ipv4_addresses, ipv6_addresses)`` tuple of lists, in the order
        ``socket.getaddrinfo`` reported them. Entries are not de-duplicated.
    """
    v4_hosts = []
    v6_hosts = []

    # Each getaddrinfo entry is (family, type, proto, canonname, sockaddr);
    # the host string is the first element of sockaddr.
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            v4_hosts.append(sockaddr[0])
        elif family == socket.AF_INET6:
            v6_hosts.append(sockaddr[0])

    return v4_hosts, v6_hosts


Timothy J. Baek's avatar
Timothy J. Baek committed
745
746
747
748
749
750
751
752
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Search the web with the selected engine and return its results.

    Supported ``engine`` values and the configuration each one requires:
    - "searxng": SEARXNG_QUERY_URL
    - "google_pse": GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - "brave": BRAVE_SEARCH_API_KEY
    - "serpstack": SERPSTACK_API_KEY
    - "serper": SERPER_API_KEY
    - "serply": SERPLY_API_KEY

    Args:
        engine (str): Name of the search engine to use
        query (str): The query to search for

    Raises:
        Exception: If the engine is unknown or its required configuration
            value is not set.
    """
    cfg = app.state.config

    # TODO: add playwright to search the web
    if engine == "searxng":
        if not cfg.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(
            cfg.SEARXNG_QUERY_URL,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )
    elif engine == "google_pse":
        if not (cfg.GOOGLE_PSE_API_KEY and cfg.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            cfg.GOOGLE_PSE_API_KEY,
            cfg.GOOGLE_PSE_ENGINE_ID,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )
    elif engine == "brave":
        if not cfg.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(
            cfg.BRAVE_SEARCH_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )
    elif engine == "serpstack":
        if not cfg.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            cfg.SERPSTACK_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
            https_enabled=cfg.SERPSTACK_HTTPS,
        )
    elif engine == "serper":
        if not cfg.SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(
            cfg.SERPER_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )
    elif engine == "serply":
        if not cfg.SERPLY_API_KEY:
            raise Exception("No SERPLY_API_KEY found in environment variables")
        return search_serply(
            cfg.SERPLY_API_KEY,
            query,
            cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        )
    else:
        raise Exception("No search engine API key found in environment variables")


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
825
826
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Run a web search and index the pages it returns.

    The query is sent to the configured search engine; the result links are
    then loaded via ``get_web_loader`` and stored under
    ``form_data.collection_name``. When that is an empty string, the first
    63 characters of the SHA-256 of the query are used instead. An existing
    collection with the same name is overwritten.

    Returns:
        dict: ``status``, ``collection_name`` and ``filenames`` (result URLs).

    Raises:
        HTTPException: 400 with the web-search error message if the search
            fails, or 400 with the default error message if loading or
            storing the pages fails.
    """
    try:
        log.info(f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}")
        web_results = search_web(
            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        urls = [result.link for result in web_results]
        # NOTE(review): unlike store_web, the configured SSL-verification
        # setting is not forwarded here, so the loader default applies —
        # confirm this is intended.
        loader = get_web_loader(urls)
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


864
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in the vector DB.

    Args:
        data: Documents as returned by a loader's ``load()``.
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.

    Raises:
        ValueError: If splitting produces no chunks (empty content).
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if len(docs) > 0:
        log.info(f"store_data_in_vector_db {docs}")
        # Return the bare bool: a ``(result, None)`` tuple is always truthy
        # and would hide storage failures from ``if result:`` checks.
        return store_docs_in_vector_db(docs, collection_name, overwrite)
    else:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
879
880
881


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw text into chunks and store them in the vector DB.

    Args:
        text: The text to index.
        metadata: Metadata dict passed to the splitter for the text.
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    docs = text_splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(docs, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
893
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a newly created collection.

    Args:
        docs: Chunks exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection with this name is
            deleted before the new one is created.

    Returns:
        ``True`` on success, or when the failure is a
        ``UniqueConstraintError`` (matched by class name); ``False`` on any
        other exception. Exceptions are logged, not raised.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        if overwrite:
            for collection in CHROMA_CLIENT.list_collections():
                if collection_name == collection.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        # Newlines are flattened for embedding only; the stored documents
        # keep their original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        for batch in create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        ):
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        if e.__class__.__name__ == "UniqueConstraintError":
            return True

        return False


938
939
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a document loader for a file from its extension and MIME type.

    Args:
        filename: File name; the lower-cased text after the last ``.`` is
            used as the extension (the whole name if it has no dot).
        file_content_type: MIME type of the file; may be falsy.
        file_path: Path handed to the selected loader.

    Returns:
        A ``(loader, known_type)`` tuple. ``known_type`` is ``False`` only
        when no rule matched and the file falls back to ``TextLoader``.
    """
    file_ext = filename.split(".")[-1].lower()
    known_type = True

    known_source_ext = {
        "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
        "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", "r",
        "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", "nginxconf",
        "conf", "m", "mm", "plsql", "perl", "rb", "rs", "db2", "scala",
        "bash", "swift", "vue", "svelte",
    }

    if file_ext == "pdf":
        loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext in ["htm", "html"]:
        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif file_content_type == "application/epub+zip":
        loader = UnstructuredEPubLoader(file_path)
    elif (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or file_ext in ["doc", "docx"]
    ):
        loader = Docx2txtLoader(file_path)
    elif file_content_type in [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ] or file_ext in ["xls", "xlsx"]:
        loader = UnstructuredExcelLoader(file_path)
    elif file_content_type in [
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ] or file_ext in ["ppt", "pptx"]:
        loader = UnstructuredPowerPointLoader(file_path)
    elif file_ext in known_source_ext or (
        file_content_type and file_content_type.find("text/") >= 0
    ):
        loader = TextLoader(file_path, autodetect_encoding=True)
    else:
        # Unrecognised type: still try plain text, but flag it to the caller.
        loader = TextLoader(file_path, autodetect_encoding=True)
        known_type = False

    return loader, known_type


1030
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file to the upload directory and index its content.

    The file is written to ``UPLOAD_DIR`` under the base name of the
    uploaded filename. When no ``collection_name`` is supplied, the first
    63 characters of the SHA-256 of the saved file are used.

    Returns:
        dict: ``status``, ``collection_name``, ``filename`` and
        ``known_type`` when storing succeeds.

    Raises:
        HTTPException: 400 with the pandoc-not-installed message when the
            error text mentions missing pandoc, otherwise 400 with the
            default error message.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # Drop any directory components supplied by the client.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)

            if result:
                return {
                    "status": True,
                    "collection_name": collection_name,
                    "filename": filename,
                    "known_type": known_type,
                }
        except Exception as e:
            # NOTE(review): this HTTPException is itself caught by the outer
            # ``except Exception`` below, so clients receive a 400 rather
            # than this 500 — confirm which status is intended.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=e,
            )
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
1085
1086


1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
class TextRAGForm(BaseModel):
    """Request body for the ``/text`` endpoint."""

    # Stored in the chunk metadata under the "name" key.
    name: str
    # Raw text to split and index.
    content: str
    # Target collection; when omitted, a name is derived from the content hash.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text snippet in the vector store.

    When no ``collection_name`` is supplied, the first 63 characters of the
    SHA-256 of the content are used. Each chunk's metadata records the
    supplied name and the id of the requesting user.

    Returns:
        dict: ``status`` and ``collection_name`` on success.

    Raises:
        HTTPException: 500 with the default error message when storing
            fails.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncate like the other endpoints: a full SHA-256 hex digest is
        # 64 characters, one more than a collection name may have.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


1118
1119
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under ``DOCS_DIR`` (admin only).

    Each file is stored in a collection named after the first 63 characters
    of its SHA-256. When storing succeeds and no document with the sanitized
    file name exists yet, a document record is inserted, tagged with the
    folders returned by ``extract_folders_after_data_docs``. Failures for an
    individual file are logged and do not stop the scan.

    Returns:
        ``True`` once all paths have been visited.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if path.is_file() and not path.name.startswith("."):
                tags = extract_folders_after_data_docs(path)
                filename = path.name
                file_content_type = mimetypes.guess_type(path)

                with open(path, "rb") as f:
                    collection_name = calculate_sha256(f)[:63]

                # guess_type returns (type, encoding); only the type is used.
                loader, known_type = get_loader(
                    filename, file_content_type[0], str(path)
                )
                data = loader.load()

                try:
                    result = store_data_in_vector_db(data, collection_name)

                    if result:
                        sanitized_filename = sanitize_filename(filename)
                        doc = Documents.get_doc_by_name(sanitized_filename)

                        if doc is None:
                            if len(tags):
                                content = json.dumps(
                                    {"tags": [{"name": name} for name in tags]}
                                )
                            else:
                                content = "{}"

                            doc = Documents.insert_new_doc(
                                user.id,
                                DocumentForm(
                                    **{
                                        "name": sanitized_filename,
                                        "title": filename,
                                        "collection_name": collection_name,
                                        "filename": filename,
                                        "content": content,
                                    }
                                ),
                            )
                except Exception as e:
                    log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1179
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the Chroma client (admin only) by calling ``CHROMA_CLIENT.reset()``."""
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
1182
1183


Timothy J. Baek's avatar
Timothy J. Baek committed
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    """Delete everything inside the upload directory (admin only).

    Files and symlinks are unlinked and sub-directories are removed
    recursively. Failures are logged and never raised.

    Returns:
        Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"
    try:
        if not os.path.exists(folder):
            log.warning(f"The directory {folder} does not exist")
            return True

        # Iterate over all the files and directories in the upload directory
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove the file or link
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove the directory
            except Exception as e:
                log.error(f"Failed to delete {file_path}. Reason: {e}")
    except Exception as e:
        log.error(f"Failed to process the directory {folder}. Reason: {e}")

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1208
@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Clear the upload directory and reset the Chroma client (admin only).

    Deletion failures for individual entries and a failing client reset are
    logged and never raised.

    Returns:
        Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"
    # A missing upload directory means there is nothing to delete; skip it
    # so the vector DB reset below still runs.
    if os.path.exists(folder):
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                log.error("Failed to delete %s. Reason: %s", file_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237


# Debug routes, registered only when ENV is "dev". Neither route declares a
# user dependency.
if ENV == "dev":

    # Return the embedding function's output for the fixed string "hello world".
    @app.get("/ef")
    async def get_embeddings():
        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}

    # Return the embedding function's output for the text given in the path.
    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        return {"result": app.state.EMBEDDING_FUNCTION(text)}