main.py 40.5 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
mindspawn's avatar
mindspawn committed
12
from datetime import datetime
13
14

from pathlib import Path
15
from typing import List, Union, Sequence
Timothy J. Baek's avatar
Timothy J. Baek committed
16

17
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
18

Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
21
22
23
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
24
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
25
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
26
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
28
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
29
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
31
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
33
    YoutubeLoader,
mindspawn's avatar
mindspawn committed
34
    OutlookMessageLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
35
)
36
37
from langchain.text_splitter import RecursiveCharacterTextSplitter

38
39
40
41
42
import validators
import urllib.parse
import socket


43
44
from pydantic import BaseModel
from typing import Optional
45
import mimetypes
46
import uuid
47
48
import json

49
import sentence_transformers
50

Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
51
from apps.webui.models.documents import (
52
53
54
55
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
56

57
from apps.rag.utils import (
58
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
59
60
61
62
63
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
64
)
Timothy J. Baek's avatar
Timothy J. Baek committed
65

Timothy J. Baek's avatar
Timothy J. Baek committed
66
67
68
69
70
71
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
72
from apps.rag.search.serply import search_serply
Timothy J. Baek's avatar
Timothy J. Baek committed
73

74
75
76
77
78
79
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
80
from utils.utils import get_current_user, get_admin_user
81

82
from config import (
83
    AppConfig,
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
84
    ENV,
85
    SRC_LOG_LEVELS,
86
87
    UPLOAD_DIR,
    DOCS_DIR,
88
89
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
90
    RAG_EMBEDDING_ENGINE,
91
    RAG_EMBEDDING_MODEL,
92
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
93
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
94
    ENABLE_RAG_HYBRID_SEARCH,
95
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
96
    RAG_RERANKING_MODEL,
97
    PDF_EXTRACT_IMAGES,
98
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
99
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
100
101
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
102
    DEVICE_TYPE,
103
104
105
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
106
    RAG_TEMPLATE,
107
    ENABLE_RAG_LOCAL_WEB_FETCH,
108
    YOUTUBE_LOADER_LANGUAGE,
Timothy J. Baek's avatar
Timothy J. Baek committed
109
    ENABLE_RAG_WEB_SEARCH,
Timothy J. Baek's avatar
Timothy J. Baek committed
110
    RAG_WEB_SEARCH_ENGINE,
Timothy J. Baek's avatar
Timothy J. Baek committed
111
112
113
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
Timothy J. Baek's avatar
Timothy J. Baek committed
114
    BRAVE_SEARCH_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
115
116
117
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
118
    SERPLY_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
119
    RAG_WEB_SEARCH_RESULT_COUNT,
120
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
121
    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
122
)
123

124
125
from constants import ERROR_MESSAGES

126
127
128
# Module logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Runtime settings live on app.state.config and are seeded from the values
# imported from the config module. The endpoints below read and update them.
app.state.config = AppConfig()

# Retrieval
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH

# Web loader
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION

# Chunking
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking models and prompt template
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# OpenAI-compatible embedding endpoint
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# PDF loading
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

# YouTube loader. The translation setting is kept directly on app.state
# (not on app.state.config) and starts out unset.
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None

# Web search
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE

app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS


176
177
178
179
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load or clear the local SentenceTransformer used for embeddings.

    A model is loaded only when ``embedding_model`` is non-empty and the
    configured ``RAG_EMBEDDING_ENGINE`` is the empty string. In every other
    case ``app.state.sentence_transformer_ef`` is reset to ``None``.

    Args:
        embedding_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    wants_local_model = (
        bool(embedding_model) and app.state.config.RAG_EMBEDDING_ENGINE == ""
    )
    if not wants_local_model:
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load or clear the CrossEncoder used for reranking.

    When ``reranking_model`` is empty, ``app.state.sentence_transformer_rf``
    is set to ``None``; otherwise a ``CrossEncoder`` is built from the path
    returned by ``get_model_path``.

    Args:
        reranking_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the configured models once at import time. The *_AUTO_UPDATE flags are
# passed through as the ``update_model`` argument.
update_embedding_model(
    embedding_model=app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    reranking_model=app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Build the embedding callable used by the query and store endpoints from the
# current engine, model, local transformer and OpenAI settings.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
224
225
# CORS: every origin, method and header is allowed, with credentials enabled.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended for deployed environments.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
236
class CollectionNameForm(BaseModel):
    # Base request body for endpoints that write into a vector collection.
    # The name defaults to "test" when the client omits it.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
240
class UrlForm(CollectionNameForm):
    # Request body for the URL-based store endpoints (/youtube, /web):
    # the inherited collection name plus the URL to load.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
243

244
245
246
247
class SearchForm(CollectionNameForm):
    # Request body carrying a search query plus the inherited collection name.
    query: str


Timothy J. Baek's avatar
Timothy J. Baek committed
248
249
@app.get("/")
async def get_status():
    """Return a liveness flag together with the active chunking, template and model settings."""
    config = app.state.config
    payload = {
        "status": True,
        "chunk_size": config.CHUNK_SIZE,
        "chunk_overlap": config.CHUNK_OVERLAP,
        "template": config.RAG_TEMPLATE,
        "embedding_engine": config.RAG_EMBEDDING_ENGINE,
        "embedding_model": config.RAG_EMBEDDING_MODEL,
        "reranking_model": config.RAG_RERANKING_MODEL,
        "openai_batch_size": config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }
    return payload


Timothy J. Baek's avatar
Timothy J. Baek committed
262
263
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the current embedding engine, model and OpenAI settings (admin only)."""
    config = app.state.config
    openai_config = {
        "url": config.OPENAI_API_BASE_URL,
        "key": config.OPENAI_API_KEY,
        "batch_size": config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }
    return {
        "status": True,
        "embedding_engine": config.RAG_EMBEDDING_ENGINE,
        "embedding_model": config.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
276
277
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the currently configured reranking model (admin only)."""
    current_model = app.state.config.RAG_RERANKING_MODEL
    return {"status": True, "reranking_model": current_model}
Steven Kreitzer's avatar
Steven Kreitzer committed
282
283


284
285
286
class OpenAIConfigForm(BaseModel):
    # Connection settings for the OpenAI-compatible embedding endpoint.
    url: str
    key: str
    # Optional; the update endpoint falls back to 1 when this is unset or 0.
    batch_size: Optional[int] = None
288
289


290
class EmbeddingModelUpdateForm(BaseModel):
    # Request body for POST /embedding/update.
    # openai_config is only applied when the engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
296
297
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function (admin only).

    The engine and model are stored first. For the "ollama" and "openai"
    engines the supplied OpenAI settings (if any) are stored too, with the
    batch size falling back to 1. The local model is then reloaded and
    ``app.state.EMBEDDING_FUNCTION`` is rebuilt.

    Raises:
        HTTPException: 500 if any step fails; the error is logged first.
    """
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        config = app.state.config
        config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        uses_remote_engine = config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
        openai_config = form_data.openai_config
        if uses_remote_engine and openai_config is not None:
            config.OPENAI_API_BASE_URL = openai_config.url
            config.OPENAI_API_KEY = openai_config.key
            config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = openai_config.batch_size or 1

        update_embedding_model(config.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            config.RAG_EMBEDDING_ENGINE,
            config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            config.OPENAI_API_KEY,
            config.OPENAI_API_BASE_URL,
            config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        return {
            "status": True,
            "embedding_engine": config.RAG_EMBEDDING_ENGINE,
            "embedding_model": config.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": config.OPENAI_API_BASE_URL,
                "key": config.OPENAI_API_KEY,
                "batch_size": config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
344
345


Steven Kreitzer's avatar
Steven Kreitzer committed
346
347
class RerankingModelUpdateForm(BaseModel):
    # Request body for POST /reranking/update: the reranking model to switch to.
    reranking_model: str
348

Steven Kreitzer's avatar
Steven Kreitzer committed
349
350
351
352
353
354

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it (admin only).

    Stores the new model name in ``app.state.config.RAG_RERANKING_MODEL`` and
    reloads the cross-encoder through ``update_reranking_model``.

    Raises:
        HTTPException: 500 if loading the model fails; the error is logged first.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # Pass update_model=True as an argument. The previous code had the
        # ``True`` outside the call parentheses, which only built a discarded
        # tuple and left update_model at its default of False.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
374
375
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF, chunking, YouTube and web/web-search settings (admin only)."""
    config = app.state.config

    chunk_settings = {
        "chunk_size": config.CHUNK_SIZE,
        "chunk_overlap": config.CHUNK_OVERLAP,
    }
    youtube_settings = {
        "language": config.YOUTUBE_LOADER_LANGUAGE,
        # The translation setting is stored on app.state, not app.state.config.
        "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
    }
    search_settings = {
        "enabled": config.ENABLE_RAG_WEB_SEARCH,
        "engine": config.RAG_WEB_SEARCH_ENGINE,
        "searxng_query_url": config.SEARXNG_QUERY_URL,
        "google_pse_api_key": config.GOOGLE_PSE_API_KEY,
        "google_pse_engine_id": config.GOOGLE_PSE_ENGINE_ID,
        "brave_search_api_key": config.BRAVE_SEARCH_API_KEY,
        "serpstack_api_key": config.SERPSTACK_API_KEY,
        "serpstack_https": config.SERPSTACK_HTTPS,
        "serper_api_key": config.SERPER_API_KEY,
        "serply_api_key": config.SERPLY_API_KEY,
        "result_count": config.RAG_WEB_SEARCH_RESULT_COUNT,
        "concurrent_requests": config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    }

    return {
        "status": True,
        "pdf_extract_images": config.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
        "youtube": youtube_settings,
        "web": {
            "ssl_verification": config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": search_settings,
        },
    }


class ChunkParamUpdateForm(BaseModel):
    # Chunking parameters accepted by POST /config/update; both are required
    # and are copied to CHUNK_SIZE / CHUNK_OVERLAP on app.state.config.
    chunk_size: int
    chunk_overlap: int


412
413
414
415
416
class YoutubeLoaderConfig(BaseModel):
    # YouTube loader settings accepted by POST /config/update; they are later
    # passed as ``language`` / ``translation`` to YoutubeLoader.from_youtube_url.
    language: List[str]
    translation: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
417
class WebSearchConfig(BaseModel):
    # Web-search settings accepted by POST /config/update. Only ``enabled`` is
    # required; every other field defaults to None.
    enabled: bool
    engine: Optional[str] = None

    # Per-engine credentials and endpoints
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    serply_api_key: Optional[str] = None

    # Limits
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
432
433
434
435
436
class WebConfig(BaseModel):
    # Web section of the config update payload: required search settings plus
    # an optional SSL-verification flag for the web loader.
    search: WebSearchConfig
    web_loader_ssl_verification: Optional[bool] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
437
class ConfigUpdateForm(BaseModel):
    # Request body for POST /config/update. Every section is optional; a
    # section left as None is skipped by the handler.
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
442
443
444
445


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Update the PDF, chunking, YouTube and web/web-search settings (admin only).

    Each top-level section of ``form_data`` is applied only when it is not
    ``None``. Within the ``web`` section the search fields are copied as sent
    (including ``None``), while the SSL-verification flag is changed only when
    a value is supplied.

    Returns:
        The resulting settings, in the same shape as ``GET /config``.
    """
    app.state.config.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else app.state.config.PDF_EXTRACT_IMAGES
    )

    if form_data.chunk is not None:
        app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
        app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        # web_loader_ssl_verification defaults to None. Only overwrite the
        # stored flag when the client actually sent one, so an omitted field
        # cannot replace it with a falsy None (same rule as
        # pdf_extract_images above).
        if form_data.web.web_loader_ssl_verification is not None:
            app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
                form_data.web.web_loader_ssl_verification
            )

        app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
        app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
        app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
        app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
        app.state.config.GOOGLE_PSE_ENGINE_ID = (
            form_data.web.search.google_pse_engine_id
        )
        app.state.config.BRAVE_SEARCH_API_KEY = (
            form_data.web.search.brave_search_api_key
        )
        app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
        app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
        app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
        app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
        app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
        app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
            form_data.web.search.concurrent_requests
        )

    return {
        "status": True,
        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.config.CHUNK_SIZE,
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
                "serper_api_key": app.state.config.SERPER_API_KEY,
                "serply_api_key": app.state.config.SERPLY_API_KEY,
                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
513
514


Timothy J. Baek's avatar
Timothy J. Baek committed
515
516
517
518
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the currently configured RAG prompt template."""
    template = app.state.config.RAG_TEMPLATE
    return {"status": True, "template": template}


523
524
525
526
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the template, top-k, relevance threshold and hybrid-search flag (admin only)."""
    config = app.state.config
    settings = {
        "status": True,
        "template": config.RAG_TEMPLATE,
        "k": config.TOP_K,
        "r": config.RELEVANCE_THRESHOLD,
        "hybrid": config.ENABLE_RAG_HYBRID_SEARCH,
    }
    return settings
Timothy J. Baek's avatar
Timothy J. Baek committed
532
533


534
535
class QuerySettingsForm(BaseModel):
    # Request body for POST /query/settings/update. Fields left unset (or
    # falsy) are replaced by the handler's fallback values.
    k: Optional[int] = None  # number of results to retrieve
    r: Optional[float] = None  # relevance threshold
    template: Optional[str] = None  # RAG prompt template
    hybrid: Optional[bool] = None  # hybrid-search toggle
539
540
541
542
543
544


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Overwrite the query settings (admin only).

    Any missing or falsy field is replaced by a fallback: the imported
    ``RAG_TEMPLATE`` for the template, ``4`` for ``k``, ``0.0`` for ``r`` and
    ``False`` for ``hybrid``.
    """
    config = app.state.config

    config.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    config.TOP_K = form_data.k or 4
    config.RELEVANCE_THRESHOLD = form_data.r or 0.0
    config.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return {
        "status": True,
        "template": config.RAG_TEMPLATE,
        "k": config.TOP_K,
        "r": config.RELEVANCE_THRESHOLD,
        "hybrid": config.ENABLE_RAG_HYBRID_SEARCH,
    }
560
561


562
class QueryDocForm(BaseModel):
    # Request body for POST /query/doc.
    collection_name: str
    query: str
    # Optional per-request overrides; the handler falls back to the configured
    # TOP_K / RELEVANCE_THRESHOLD when these are unset or falsy.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): not read by query_doc_handler, which uses the global
    # ENABLE_RAG_HYBRID_SEARCH setting instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
568
569


570
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single collection, using hybrid search when it is enabled globally.

    ``k`` and ``r`` from the request are used when truthy; otherwise the
    configured ``TOP_K`` / ``RELEVANCE_THRESHOLD`` apply.

    Raises:
        HTTPException: 400 if the query fails; the error is logged first.
    """
    try:
        hybrid_enabled = app.state.config.ENABLE_RAG_HYBRID_SEARCH
        top_k = form_data.k or app.state.config.TOP_K

        if not hybrid_enabled:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        # NOTE(review): an explicit r of 0.0 is falsy and therefore falls back
        # to the configured threshold.
        threshold = form_data.r or app.state.config.RELEVANCE_THRESHOLD
        return query_doc_with_hybrid_search(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=threshold,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
600
601


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
602
603
604
class QueryCollectionsForm(BaseModel):
    # Request body for POST /query/collection.
    collection_names: List[str]
    query: str
    # Optional per-request overrides; the handler falls back to the configured
    # TOP_K / RELEVANCE_THRESHOLD when these are unset or falsy.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): not read by query_collection_handler, which uses the
    # global ENABLE_RAG_HYBRID_SEARCH setting instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
608
609


610
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections, using hybrid search when it is enabled globally.

    ``k`` and ``r`` from the request are used when truthy; otherwise the
    configured ``TOP_K`` / ``RELEVANCE_THRESHOLD`` apply.

    Raises:
        HTTPException: 400 if the query fails; the error is logged first.
    """
    try:
        hybrid_enabled = app.state.config.ENABLE_RAG_HYBRID_SEARCH
        top_k = form_data.k or app.state.config.TOP_K

        if not hybrid_enabled:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        # NOTE(review): an explicit r of 0.0 is falsy and therefore falls back
        # to the configured threshold.
        threshold = form_data.r or app.state.config.RELEVANCE_THRESHOLD
        return query_collection_with_hybrid_search(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=threshold,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
641
642


Timothy J. Baek's avatar
Timothy J. Baek committed
643
644
645
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube URL with YoutubeLoader and store the result in the vector DB.

    The loader uses the configured language list and translation setting.
    An empty collection name is replaced by the first 63 characters of
    ``calculate_sha256_string(url)``. Existing data in the collection is
    overwritten.

    Raises:
        HTTPException: 400 if loading or storing fails; the error is logged first.
    """
    try:
        youtube_loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        documents = youtube_loader.load()

        target_collection = form_data.collection_name
        if target_collection == "":
            target_collection = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(documents, target_collection, overwrite=True)
        return {
            "status": True,
            "collection_name": target_collection,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


672
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page through ``get_web_loader`` and index its content.

    Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    SSL verification follows
    ``app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION``. When the
    submitted collection name is an empty string, the first 63 characters of
    ``calculate_sha256_string(url)`` are used instead. The target collection
    is stored with ``overwrite=True``.

    Returns:
        A dict with ``status``, ``collection_name`` and ``filename`` (the URL).

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.DEFAULT`` for any exception
            raised while validating, loading or storing.
    """
    try:
        web_loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        documents = web_loader.load()

        # Only an empty string triggers the derived name; other values pass through.
        target_collection = (
            calculate_sha256_string(form_data.url)[:63]
            if form_data.collection_name == ""
            else form_data.collection_name
        )

        store_data_in_vector_db(documents, target_collection, overwrite=True)
        response = {
            "status": True,
            "collection_name": target_collection,
            "filename": form_data.url,
        }
        return response
    except Exception as e:
        log.exception(e)
        error_detail = ERROR_MESSAGES.DEFAULT(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=error_detail,
        )

699

700
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a ``WebBaseLoader`` for one URL or a sequence of URLs.

    Args:
        url: A single URL or a sequence of URLs; checked with ``validate_url``.
        verify_ssl: Forwarded to ``WebBaseLoader``.

    Returns:
        A ``WebBaseLoader`` configured with
        ``RAG_WEB_SEARCH_CONCURRENT_REQUESTS`` as ``requests_per_second`` and
        ``continue_on_failure=True``.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` when validation fails.
    """
    url_is_acceptable = validate_url(url)
    if url_is_acceptable:
        return WebBaseLoader(
            url,
            verify_ssl=verify_ssl,
            requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            continue_on_failure=True,
        )

    raise ValueError(ERROR_MESSAGES.INVALID_URL)
710
711


712
713
714
715
def validate_url(url: Union[str, Sequence[str]]):
    """Validate a URL, or every URL in a sequence.

    A string must pass ``validators.url``. When ``ENABLE_RAG_LOCAL_WEB_FETCH``
    is falsy, the URL's hostname is also resolved and rejected if any resolved
    address is reported as private by ``validators.ipv4`` / ``validators.ipv6``.

    Args:
        url: A single URL string or a sequence of them.

    Returns:
        ``True`` for an accepted string, the conjunction of the element
        results for a sequence, and ``False`` for any other type.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` for a malformed URL or one
            resolving to a private address.
    """
    if not isinstance(url, str):
        if isinstance(url, Sequence):
            return all(validate_url(candidate) for candidate in url)
        return False

    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    if ENABLE_RAG_LOCAL_WEB_FETCH:
        return True

    # Local web fetch is disabled: reject URLs that resolve to private addresses.
    # This is technically still vulnerable to DNS rebinding attacks, as we
    # don't control WebBaseLoader.
    hostname = urllib.parse.urlparse(url).hostname
    ipv4_addresses, ipv6_addresses = resolve_hostname(hostname)

    if any(validators.ipv4(address, private=True) for address in ipv4_addresses):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    if any(validators.ipv6(address, private=True) for address in ipv6_addresses):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    return True


736
737
738
739
740
741
742
743
744
745
746
def resolve_hostname(hostname):
    """Resolve ``hostname`` and return its distinct IPv4 and IPv6 addresses.

    ``socket.getaddrinfo`` yields one entry per address family / socket type
    combination, so the same address usually appears several times. Each
    address is reported once, in first-seen order.

    Args:
        hostname: Host name or literal IP address to resolve.

    Returns:
        A ``(ipv4_addresses, ipv6_addresses)`` tuple of lists of strings.

    Raises:
        socket.gaierror: Propagated from ``socket.getaddrinfo`` when the name
            cannot be resolved.
    """
    addr_info = socket.getaddrinfo(hostname, None)

    # info[0] is the address family and info[4][0] the address string;
    # dict.fromkeys drops duplicates while preserving order.
    ipv4_addresses = list(
        dict.fromkeys(info[4][0] for info in addr_info if info[0] == socket.AF_INET)
    )
    ipv6_addresses = list(
        dict.fromkeys(info[4][0] for info in addr_info if info[0] == socket.AF_INET6)
    )

    return ipv4_addresses, ipv6_addresses


Timothy J. Baek's avatar
Timothy J. Baek committed
747
748
749
750
751
752
753
754
def search_web(engine: str, query: str) -> list[SearchResult]:
    """Run ``query`` against the selected search engine.

    The engine is chosen by name and must have its settings present on
    ``app.state.config``:

    - "searxng": SEARXNG_QUERY_URL
    - "google_pse": GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - "brave": BRAVE_SEARCH_API_KEY
    - "serpstack": SERPSTACK_API_KEY (SERPSTACK_HTTPS is forwarded as well)
    - "serper": SERPER_API_KEY
    - "serply": SERPLY_API_KEY

    Every engine is asked for ``RAG_WEB_SEARCH_RESULT_COUNT`` results.

    Args:
        engine (str): Name of the search engine to use
        query (str): The query to search for

    Raises:
        Exception: When the selected engine's settings are missing, or when
            ``engine`` is not one of the names above.
    """
    config = app.state.config

    # TODO: add playwright to search the web
    if engine == "searxng":
        if not config.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(
            config.SEARXNG_QUERY_URL,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "google_pse":
        if not (config.GOOGLE_PSE_API_KEY and config.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            config.GOOGLE_PSE_API_KEY,
            config.GOOGLE_PSE_ENGINE_ID,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "brave":
        if not config.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(
            config.BRAVE_SEARCH_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "serpstack":
        if not config.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            config.SERPSTACK_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
            https_enabled=config.SERPSTACK_HTTPS,
        )

    if engine == "serper":
        if not config.SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(
            config.SERPER_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    if engine == "serply":
        if not config.SERPLY_API_KEY:
            raise Exception("No SERPLY_API_KEY found in environment variables")
        return search_serply(
            config.SERPLY_API_KEY,
            query,
            config.RAG_WEB_SEARCH_RESULT_COUNT,
        )

    # Reached for any engine name not handled above.
    raise Exception("No search engine API key found in environment variables")


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
827
828
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Search the web for ``form_data.query`` and index the result pages.

    The search runs on ``app.state.config.RAG_WEB_SEARCH_ENGINE``; the link of
    every result is then loaded through ``get_web_loader`` and stored. When
    the submitted collection name is an empty string, the first 63 characters
    of ``calculate_sha256_string(query)`` are used instead. The target
    collection is stored with ``overwrite=True``.

    Returns:
        A dict with ``status``, ``collection_name`` and ``filenames`` (the
        result URLs).

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.WEB_SEARCH_ERROR`` when the
            search fails, or 400 with ``ERROR_MESSAGES.DEFAULT`` when loading
            or storing the pages fails.
    """
    try:
        # Use the module logger, like the rest of this file, rather than the
        # root logger.
        log.info(
            f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
        )
        web_results = search_web(
            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        # log.exception already records the message and traceback.
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        urls = [result.link for result in web_results]
        # Honour the SSL verification setting, as the /web endpoint does.
        loader = get_web_loader(
            urls,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


866
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in the vector DB.

    Chunking uses ``RecursiveCharacterTextSplitter`` with
    ``app.state.config.CHUNK_SIZE`` / ``CHUNK_OVERLAP``.

    Args:
        data: Documents as returned by a loader's ``load()``.
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``. Previously this was
        wrapped in a ``(result, None)`` tuple, which is always truthy and made
        ``if result:`` checks in callers unable to see a failed store.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` when splitting produces
            no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    return store_docs_in_vector_db(docs, collection_name, overwrite)
881
882
883


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw text into chunks and store them in the vector DB.

    Args:
        text: The text to index.
        metadata: Metadata dict attached to the documents created from ``text``.
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The result of ``store_docs_in_vector_db``.
    """
    splitter_options = {
        "chunk_size": app.state.config.CHUNK_SIZE,
        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        "add_start_index": True,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)

    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
895
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a new Chroma collection.

    Args:
        docs: Documents exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection with the same name is
            deleted before the new one is created.

    Returns:
        ``True`` when the chunks were added, or when the failure was an
        exception whose class is named ``UniqueConstraintError``; ``False``
        for any other exception (which is logged, not raised).
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    # ChromaDB does not like datetime formats
    # for meta-data so convert them to string.
    for metadata in metadatas:
        datetime_keys = [
            key for key, value in metadata.items() if isinstance(value, datetime)
        ]
        for key in datetime_keys:
            metadata[key] = str(metadata[key])

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if collection_name != existing.name:
                    continue
                log.info(f"deleting existing collection {collection_name}")
                CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        # Newlines are flattened for embedding only; the stored documents keep
        # the original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # Matched by class name only; this one failure is reported as success.
        return type(e).__name__ == "UniqueConstraintError"


947
948
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a document loader for a file from its extension and content type.

    The extension is the lower-cased text after the last ``.`` of ``filename``
    (the whole name when there is no dot). Checks run in a fixed order and the
    first match wins.

    Args:
        filename: Name of the file, used only for its extension.
        file_content_type: MIME type of the file; may be falsy.
        file_path: Path handed to the loader.

    Returns:
        A ``(loader, known_type)`` tuple. ``known_type`` is ``False`` only
        when nothing matched and the ``TextLoader`` fallback is used.
    """
    file_ext = filename.rsplit(".", 1)[-1].lower()

    known_source_ext = (
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
        "msg",
    )

    if file_ext == "pdf":
        pdf_loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
        return pdf_loader, True
    if file_ext == "csv":
        return CSVLoader(file_path), True
    if file_ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if file_ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if file_ext in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if file_ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True

    is_word = (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    ) or file_ext in ("doc", "docx")
    if is_word:
        return Docx2txtLoader(file_path), True

    is_excel = file_content_type in (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ) or file_ext in ("xls", "xlsx")
    if is_excel:
        return UnstructuredExcelLoader(file_path), True

    is_powerpoint = file_content_type in (
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ) or file_ext in ("ppt", "pptx")
    if is_powerpoint:
        return UnstructuredPowerPointLoader(file_path), True

    if file_ext == "msg":
        return OutlookMessageLoader(file_path), True

    # Source-code extensions and text/* content use the same TextLoader as the
    # fallback; only the known_type flag differs.
    is_text_like = file_ext in known_source_ext or bool(
        file_content_type and "text/" in file_content_type
    )
    return TextLoader(file_path, autodetect_encoding=True), is_text_like


1042
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under ``UPLOAD_DIR`` and index its content.

    The stored file name is the base name of the uploaded file name. When no
    collection name is submitted, the first 63 characters of
    ``calculate_sha256`` of the saved file are used.

    Returns:
        A dict with ``status``, ``collection_name``, ``filename`` and
        ``known_type`` when storing reports success.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.PANDOC_NOT_INSTALLED`` when
            the failure message contains "No pandoc was found", otherwise 400
            with ``ERROR_MESSAGES.DEFAULT``.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # Keep only the final path component of the client-supplied name.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            # The context manager closes the handle even if hashing raises.
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)

            if result:
                return {
                    "status": True,
                    "collection_name": collection_name,
                    "filename": filename,
                    "known_type": known_type,
                }
        except Exception as e:
            # NOTE(review): HTTPException is itself an Exception, so this 500
            # is caught by the outer handler below and reaches the client as
            # a 400. Left unchanged to preserve the current responses.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=e,
            )
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
1097
1098


1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
class TextRAGForm(BaseModel):
    # Request body of the POST /text endpoint (store_text).

    # Stored as the "name" entry of the metadata attached to the text.
    name: str
    # The raw text that is split into chunks and stored.
    content: str
    # Target collection; when None, store_text derives a name from a hash of
    # the content.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text submitted in the request body.

    The text is stored with metadata ``name`` (from the form) and
    ``created_by`` (the requesting user's id). When no collection name is
    submitted, the first 63 characters of ``calculate_sha256_string(content)``
    are used.

    Returns:
        A dict with ``status`` and ``collection_name`` on success.

    Raises:
        HTTPException: 500 with ``ERROR_MESSAGES.DEFAULT()`` when storing
            reports failure.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncate to 63 characters like the other endpoints in this file:
        # Chroma limits collection names to 63 characters, so an untruncated
        # 64-character SHA-256 hex digest would be rejected.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


1130
1131
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every file under ``DOCS_DIR`` and register it as a document.

    Files whose name starts with ``.`` are skipped. Each file is stored in a
    collection named after the first 63 characters of its
    ``calculate_sha256``. When storing reports success and no document with
    the sanitized file name exists yet, a new document record is inserted;
    its content is a JSON object of tags from
    ``extract_folders_after_data_docs``, or ``"{}"`` when there are none.

    Failures are logged per file and do not stop the scan.

    Returns:
        Always ``True``.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            # guess_type returns (type, encoding); only the type is needed.
            content_type, _ = mimetypes.guess_type(path)

            # The context manager closes the handle even if hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, _ = get_loader(filename, content_type, str(path))
            data = loader.load()

            try:
                result = store_data_in_vector_db(data, collection_name)

                if result:
                    sanitized_filename = sanitize_filename(filename)

                    if Documents.get_doc_by_name(sanitized_filename) is None:
                        content = (
                            json.dumps({"tags": [{"name": name} for name in tags]})
                            if tags
                            else "{}"
                        )
                        Documents.insert_new_doc(
                            user.id,
                            DocumentForm(
                                **{
                                    "name": sanitized_filename,
                                    "title": filename,
                                    "collection_name": collection_name,
                                    "filename": filename,
                                    "content": content,
                                }
                            ),
                        )
            except Exception as e:
                # Best effort: a file that cannot be stored is logged and skipped.
                log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1191
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the vector database by calling ``CHROMA_CLIENT.reset()``.

    Requires the ``get_admin_user`` dependency. Nothing is returned, and any
    exception from the client propagates to the caller.
    """
    CHROMA_CLIENT.reset()
    return None
Timothy J. Baek's avatar
Timothy J. Baek committed
1194
1195


Timothy J. Baek's avatar
Timothy J. Baek committed
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    """Delete everything inside ``UPLOAD_DIR``.

    Files and symlinks are unlinked and directories are removed recursively.
    A missing directory or a failed deletion is logged and otherwise ignored.

    Returns:
        Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"
    try:
        # Check if the directory exists
        if os.path.exists(folder):
            # Iterate over all the files and directories in the specified directory
            for filename in os.listdir(folder):
                file_path = os.path.join(folder, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)  # Remove the file or link
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)  # Remove the directory
                except Exception as e:
                    # Reported through the module logger, like the /reset
                    # endpoint, instead of print().
                    log.error(f"Failed to delete {file_path}. Reason: {e}")
        else:
            log.warning(f"The directory {folder} does not exist")
    except Exception as e:
        log.error(f"Failed to process the directory {folder}. Reason: {e}")

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1220
@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything inside ``UPLOAD_DIR`` and reset the vector database.

    Files and symlinks are unlinked and directories are removed recursively.
    Every failure (listing the directory, deleting an entry, resetting the
    Chroma client) is logged and otherwise ignored.

    Returns:
        Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"

    try:
        filenames = os.listdir(folder)
    except OSError as e:
        # A missing or unreadable upload directory must not prevent the
        # vector DB reset below.
        log.error("Failed to list %s. Reason: %s" % (folder, e))
        filenames = []

    for filename in filenames:
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s" % (file_path, e))

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249


# Debug routes, registered only when ENV == "dev". They return the output of
# app.state.EMBEDDING_FUNCTION for a piece of text. Neither route declares an
# authentication dependency.
if ENV == "dev":

    @app.get("/ef")
    async def get_embeddings():
        # Runs the embedding function on a fixed sample string.
        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        # Runs the embedding function on the text taken from the URL path.
        return {"result": app.state.EMBEDDING_FUNCTION(text)}