main.py 34.6 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
14
from typing import List, Union, Sequence
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
    YoutubeLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
33
)
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

36
37
38
39
40
import validators
import urllib.parse
import socket


41
42
from pydantic import BaseModel
from typing import Optional
43
import mimetypes
44
import uuid
45
46
import json

47
import sentence_transformers
48

Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
49
from apps.webui.models.documents import (
50
51
52
53
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
54

55
from apps.rag.utils import (
56
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
57
58
59
60
61
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
62
    search_web,
63
)
Timothy J. Baek's avatar
Timothy J. Baek committed
64

65
66
67
68
69
70
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
71
from utils.utils import get_current_user, get_admin_user
72

73
from config import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
74
    ENV,
75
    SRC_LOG_LEVELS,
76
77
    UPLOAD_DIR,
    DOCS_DIR,
78
79
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
80
    RAG_EMBEDDING_ENGINE,
81
    RAG_EMBEDDING_MODEL,
82
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
83
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
84
    ENABLE_RAG_HYBRID_SEARCH,
85
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
86
    RAG_RERANKING_MODEL,
87
    PDF_EXTRACT_IMAGES,
88
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
89
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
90
91
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
92
    DEVICE_TYPE,
93
94
95
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
96
    RAG_TEMPLATE,
97
    ENABLE_RAG_LOCAL_WEB_FETCH,
98
    YOUTUBE_LOADER_LANGUAGE,
Timothy J. Baek's avatar
Timothy J. Baek committed
99
    ENABLE_RAG_WEB_SEARCH,
Timothy J. Baek's avatar
Timothy J. Baek committed
100
    RAG_WEB_SEARCH_ENGINE,
Timothy J. Baek's avatar
Timothy J. Baek committed
101
102
103
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
Timothy J. Baek's avatar
Timothy J. Baek committed
104
    BRAVE_SEARCH_API_KEY,
Timothy J. Baek's avatar
Timothy J. Baek committed
105
106
107
108
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
    RAG_WEB_SEARCH_RESULT_COUNT,
109
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
110
    AppConfig,
111
)
112

113
114
from constants import ERROR_MESSAGES

115
116
117
# Module-level logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

Timothy J. Baek's avatar
Timothy J. Baek committed
118
119
# FastAPI application object that all routes in this module are registered on.
app = FastAPI()

# Mutable runtime settings; every attribute below is seeded from a value
# imported from `config`.
app.state.config = AppConfig()

# --- Retrieval ---
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH

# --- Web loader ---
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# --- Chunking ---
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# --- Models and prompt template ---
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# --- OpenAI-compatible embedding endpoint ---
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# --- PDF loading ---
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

# --- YouTube loader ---
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
# NOTE: the translation setting is kept directly on app.state (not on
# app.state.config); the handlers in this file read it from there.
app.state.YOUTUBE_LOADER_TRANSLATION = None

# --- Web search ---
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE

app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS


163
164
165
166
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load (or clear) the local SentenceTransformer used for embeddings.

    A model is loaded only when ``embedding_model`` is non-empty and the
    configured embedding engine is the empty string; otherwise
    ``app.state.sentence_transformer_ef`` is set to ``None``.

    Args:
        embedding_model: Model identifier passed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` as its second argument.
    """
    loaded_model = None
    use_local_engine = app.state.config.RAG_EMBEDDING_ENGINE == ""

    if embedding_model and use_local_engine:
        model_path = get_model_path(embedding_model, update_model)
        loaded_model = sentence_transformers.SentenceTransformer(
            model_path,
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
        )

    app.state.sentence_transformer_ef = loaded_model


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load (or clear) the CrossEncoder used for reranking.

    When ``reranking_model`` is empty, ``app.state.sentence_transformer_rf``
    is set to ``None``.

    Args:
        reranking_model: Model identifier passed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` as its second argument.
    """
    loaded_model = None

    if reranking_model:
        model_path = get_model_path(reranking_model, update_model)
        loaded_model = sentence_transformers.CrossEncoder(
            model_path,
            device=DEVICE_TYPE,
            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
        )

    app.state.sentence_transformer_rf = loaded_model


# Load both models once at import time, using the configured auto-update flags.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Embedding callable used by the storage and query code paths; it is rebuilt
# by the /embedding/update handler whenever the engine or model changes.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
210
211
# CORS: any origin, method and header is accepted, with credentials allowed.
# NOTE(review): wildcard origins combined with credentials is very permissive
# - confirm this is intended for the deployment.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
222
# Base request body carrying the target collection name.
class CollectionNameForm(BaseModel):
    # Defaults to "test" when the client omits the field.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
226
# Request body for URL-based ingestion; adds `url` to the inherited
# collection_name field.
class UrlForm(CollectionNameForm):
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
229

230
231
232
233
# Request body for web search; adds `query` to the inherited collection_name
# field.
class SearchForm(CollectionNameForm):
    query: str


Timothy J. Baek's avatar
Timothy J. Baek committed
234
235
@app.get("/")
async def get_status():
    """Return a status flag plus the active chunking, template and model settings."""
    cfg = app.state.config
    return {
        "status": True,
        "chunk_size": cfg.CHUNK_SIZE,
        "chunk_overlap": cfg.CHUNK_OVERLAP,
        "template": cfg.RAG_TEMPLATE,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "reranking_model": cfg.RAG_RERANKING_MODEL,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
247
248
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and the OpenAI-compatible endpoint settings.

    Access is gated by the ``get_admin_user`` dependency.
    """
    cfg = app.state.config
    openai_config = {
        "url": cfg.OPENAI_API_BASE_URL,
        "key": cfg.OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
260
261
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the currently configured reranking model.

    Access is gated by the ``get_admin_user`` dependency.
    """
    # NOTE: the function name contains a typo ("reraanking"); it is left
    # unchanged so the handler's public name stays the same.
    current_model = app.state.config.RAG_RERANKING_MODEL
    return {
        "status": True,
        "reranking_model": current_model,
    }
Steven Kreitzer's avatar
Steven Kreitzer committed
266
267


268
269
270
271
272
# Endpoint settings for the OpenAI-compatible embedding API; the update
# handler stores these as OPENAI_API_BASE_URL and OPENAI_API_KEY.
class OpenAIConfigForm(BaseModel):
    url: str
    key: str


273
# Request body for POST /embedding/update.
class EmbeddingModelUpdateForm(BaseModel):
    # Only applied by the handler when the engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
279
280
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function.

    Args:
        form_data: New engine and model, plus optional OpenAI-compatible
            endpoint settings (applied only for the "ollama"/"openai" engines).
        user: Resolved by the ``get_admin_user`` dependency.

    Returns:
        The resulting embedding configuration.

    Raises:
        HTTPException: 500 if anything fails while applying the change.
    """
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
            # Identity check: None is the "field omitted" sentinel.
            if form_data.openai_config is not None:
                app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.config.OPENAI_API_KEY = form_data.openai_config.key

        # Reload (or clear) the local model, then rebuild the callable that
        # the query/storage code uses.
        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        return {
            "status": True,
            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.config.OPENAI_API_BASE_URL,
                "key": app.state.config.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
320
321


Steven Kreitzer's avatar
Steven Kreitzer committed
322
323
# Request body for POST /reranking/update.
class RerankingModelUpdateForm(BaseModel):
    reranking_model: str
324

Steven Kreitzer's avatar
Steven Kreitzer committed
325
326
327
328
329
330

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it.

    Args:
        form_data: Carries the new reranking model identifier.
        user: Resolved by the ``get_admin_user`` dependency.

    Returns:
        The resulting reranking configuration.

    Raises:
        HTTPException: 500 if anything fails while applying the change.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # `True` must be an argument of the call (update_model=True). It was
        # previously written after the closing parenthesis, which only built
        # a discarded tuple and left update_model at its default of False.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
350
351
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF, chunking, YouTube-loader and web/search settings.

    Access is gated by the ``get_admin_user`` dependency.
    """
    cfg = app.state.config

    search_settings = {
        "enable": cfg.ENABLE_RAG_WEB_SEARCH,
        "engine": cfg.RAG_WEB_SEARCH_ENGINE,
        "searxng_query_url": cfg.SEARXNG_QUERY_URL,
        "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
        "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
        "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
        "serpstack_api_key": cfg.SERPSTACK_API_KEY,
        "serpstack_https": cfg.SERPSTACK_HTTPS,
        "serper_api_key": cfg.SERPER_API_KEY,
        "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
        "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    }

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            # Stored on app.state, not on app.state.config.
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": search_settings,
        },
    }


# Chunking section of the /config/update request body; both fields are
# required and are stored as CHUNK_SIZE / CHUNK_OVERLAP.
class ChunkParamUpdateForm(BaseModel):
    chunk_size: int
    chunk_overlap: int


387
388
389
390
391
# YouTube-loader section of the /config/update request body.
class YoutubeLoaderConfig(BaseModel):
    language: List[str]
    translation: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
392
393
# Web-search section of the /config/update request body. Only `enable` is
# required; every other field defaults to None.
class WebSearchConfig(BaseModel):
    enable: bool
    engine: Optional[str] = None

    # Per-provider settings.
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None

    # Result and concurrency limits.
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
406
407
408
409
410
# Web section of the /config/update request body.
# NOTE(review): GET /config reports this flag under the key "ssl_verification",
# while this form expects "web_loader_ssl_verification" - confirm clients send
# the latter.
class WebConfig(BaseModel):
    search: WebSearchConfig
    web_loader_ssl_verification: Optional[bool] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
411
# Request body for POST /config/update. Every section is optional; the
# handler skips sections that are None.
class ConfigUpdateForm(BaseModel):
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
416
417
418
419


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply a partial update to the RAG settings and return the result.

    Sections of ``form_data`` that are ``None`` are left untouched.

    Args:
        form_data: Optional PDF, chunking, YouTube-loader and web sections.
        user: Resolved by the ``get_admin_user`` dependency.

    Returns:
        The full configuration after the update, in the same shape as
        ``GET /config``.
    """
    cfg = app.state.config

    if form_data.pdf_extract_images is not None:
        cfg.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images

    if form_data.chunk is not None:
        cfg.CHUNK_SIZE = form_data.chunk.chunk_size
        cfg.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        cfg.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        # The translation setting lives on app.state, not app.state.config.
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        # Keep the current SSL-verification flag when the field is omitted.
        # Previously an omitted field (None) overwrote the flag, which
        # presumably reads as "verification off" downstream. This mirrors the
        # pdf_extract_images handling above.
        if form_data.web.web_loader_ssl_verification is not None:
            cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
                form_data.web.web_loader_ssl_verification
            )

        search = form_data.web.search
        cfg.ENABLE_RAG_WEB_SEARCH = search.enable
        cfg.RAG_WEB_SEARCH_ENGINE = search.engine
        cfg.SEARXNG_QUERY_URL = search.searxng_query_url
        cfg.GOOGLE_PSE_API_KEY = search.google_pse_api_key
        cfg.GOOGLE_PSE_ENGINE_ID = search.google_pse_engine_id
        cfg.BRAVE_SEARCH_API_KEY = search.brave_search_api_key
        cfg.SERPSTACK_API_KEY = search.serpstack_api_key
        cfg.SERPSTACK_HTTPS = search.serpstack_https
        cfg.SERPER_API_KEY = search.serper_api_key
        cfg.RAG_WEB_SEARCH_RESULT_COUNT = search.result_count
        cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = search.concurrent_requests

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enable": cfg.ENABLE_RAG_WEB_SEARCH,
                "engine": cfg.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": cfg.SEARXNG_QUERY_URL,
                "google_pse_api_key": cfg.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": cfg.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": cfg.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": cfg.SERPSTACK_API_KEY,
                "serpstack_https": cfg.SERPSTACK_HTTPS,
                "serper_api_key": cfg.SERPER_API_KEY,
                "result_count": cfg.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": cfg.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
485
486


Timothy J. Baek's avatar
Timothy J. Baek committed
487
488
489
490
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the current RAG prompt template."""
    template = app.state.config.RAG_TEMPLATE
    return {
        "status": True,
        "template": template,
    }


495
496
497
498
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the template, top-k, relevance threshold and hybrid-search flag.

    Access is gated by the ``get_admin_user`` dependency.
    """
    cfg = app.state.config
    return {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
504
505


506
507
# Request body for POST /query/settings/update; every field is optional.
class QuerySettingsForm(BaseModel):
    k: Optional[int] = None
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None
511
512
513
514
515
516


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the query settings and return the resulting values.

    Any falsy field (omitted, empty, 0, False) resets that setting to its
    fallback: the imported ``RAG_TEMPLATE``, ``4``, ``0.0`` and ``False``
    respectively.
    """
    cfg = app.state.config

    cfg.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    cfg.TOP_K = form_data.k or 4
    cfg.RELEVANCE_THRESHOLD = form_data.r or 0.0
    cfg.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
532
533


534
# Request body for POST /query/doc.
class QueryDocForm(BaseModel):
    collection_name: str
    query: str

    # Optional per-request overrides.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): not read by the /query/doc handler in this file, which
    # uses the global ENABLE_RAG_HYBRID_SEARCH flag instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
540
541


542
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single collection, with hybrid search when it is enabled.

    Args:
        form_data: Collection name, query text and optional ``k`` / ``r``
            overrides. ``form_data.hybrid`` is not consulted; the global
            ``ENABLE_RAG_HYBRID_SEARCH`` setting decides the search mode.
        user: Resolved by the ``get_current_user`` dependency.

    Returns:
        Whatever ``query_doc`` / ``query_doc_with_hybrid_search`` returns.

    Raises:
        HTTPException: 400 wrapping any exception raised during the query.
    """
    try:
        # A falsy k (None or 0) falls back to the configured TOP_K.
        top_k = form_data.k if form_data.k else app.state.config.TOP_K

        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            # Only None means "not supplied". A truthiness test would discard
            # an explicit r=0.0 and apply the configured threshold instead.
            relevance_threshold = (
                form_data.r
                if form_data.r is not None
                else app.state.config.RELEVANCE_THRESHOLD
            )
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
                reranking_function=app.state.sentence_transformer_rf,
                r=relevance_threshold,
            )

        return query_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
572
573


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
574
575
576
# Request body for POST /query/collection.
class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str

    # Optional per-request overrides.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): not read by the /query/collection handler in this file,
    # which uses the global ENABLE_RAG_HYBRID_SEARCH flag instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
580
581


582
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections, with hybrid search when it is enabled.

    Args:
        form_data: Collection names, query text and optional ``k`` / ``r``
            overrides. ``form_data.hybrid`` is not consulted; the global
            ``ENABLE_RAG_HYBRID_SEARCH`` setting decides the search mode.
        user: Resolved by the ``get_current_user`` dependency.

    Returns:
        Whatever ``query_collection`` /
        ``query_collection_with_hybrid_search`` returns.

    Raises:
        HTTPException: 400 wrapping any exception raised during the query.
    """
    try:
        # A falsy k (None or 0) falls back to the configured TOP_K.
        top_k = form_data.k if form_data.k else app.state.config.TOP_K

        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            # Only None means "not supplied". A truthiness test would discard
            # an explicit r=0.0 and apply the configured threshold instead.
            relevance_threshold = (
                form_data.r
                if form_data.r is not None
                else app.state.config.RELEVANCE_THRESHOLD
            )
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
                reranking_function=app.state.sentence_transformer_rf,
                r=relevance_threshold,
            )

        return query_collection(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
613
614


Timothy J. Baek's avatar
Timothy J. Baek committed
615
616
617
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video via ``YoutubeLoader`` and store it in the vector DB.

    Args:
        form_data: Video URL and optional collection name.
        user: Resolved by the ``get_current_user`` dependency.

    Returns:
        Dict with ``status``, the ``collection_name`` used and the URL as
        ``filename``.

    Raises:
        HTTPException: 400 wrapping any exception raised while loading or
            storing.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # collection_name is Optional, so treat None like "" and derive a
        # name from the URL hash (truncated to 63 characters).
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


644
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and store its content in the vector DB.

    Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    Args:
        form_data: Page URL and optional collection name.
        user: Resolved by the ``get_current_user`` dependency.

    Returns:
        Dict with ``status``, the ``collection_name`` used and the URL as
        ``filename``.

    Raises:
        HTTPException: 400 wrapping any exception raised while validating,
            loading or storing.
    """
    try:
        loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # collection_name is Optional, so treat None like "" and derive a
        # name from the URL hash (truncated to 63 characters).
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

671

672
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a ``WebBaseLoader`` for one URL or a sequence of URLs.

    Args:
        url: A single URL or a sequence of URLs; validated with
            ``validate_url`` first.
        verify_ssl: Passed through to ``WebBaseLoader``.

    Returns:
        A ``WebBaseLoader`` configured with ``continue_on_failure=True``.

    Raises:
        ValueError: If ``validate_url`` rejects the URL(s).
    """
    # Check if the URL is valid
    if not validate_url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    loader_kwargs = {
        "verify_ssl": verify_ssl,
        "continue_on_failure": True,
    }

    # Read the runtime setting (editable via /config/update) rather than the
    # import-time constant, which never reflects later updates.
    concurrent_requests = app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS
    # /config/update can store None here; in that case omit the argument so
    # the loader applies its own default.
    if concurrent_requests is not None:
        loader_kwargs["requests_per_second"] = concurrent_requests

    return WebBaseLoader(url, **loader_kwargs)
682
683


684
685
686
687
def validate_url(url: Union[str, Sequence[str]]):
    """Validate a URL, or every URL in a sequence.

    Returns ``True`` when the input passes, ``False`` when it is neither a
    string nor a sequence. Raises ``ValueError(ERROR_MESSAGES.INVALID_URL)``
    when a string fails ``validators.url`` or, with
    ``ENABLE_RAG_LOCAL_WEB_FETCH`` disabled, when its hostname resolves to an
    address flagged by ``validators.ipv4``/``validators.ipv6`` with
    ``private=True``.
    """
    if not isinstance(url, str):
        if isinstance(url, Sequence):
            return all(validate_url(candidate) for candidate in url)
        return False

    syntax_check = validators.url(url)
    if isinstance(syntax_check, validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    if ENABLE_RAG_LOCAL_WEB_FETCH:
        return True

    # Local web fetch is disabled: reject any URL whose hostname resolves to
    # a private IP address.
    # This is technically still vulnerable to DNS rebinding attacks, as we
    # don't control WebBaseLoader.
    hostname = urllib.parse.urlparse(url).hostname
    ipv4_addresses, ipv6_addresses = resolve_hostname(hostname)
    address_checks = (
        (validators.ipv4, ipv4_addresses),
        (validators.ipv6, ipv6_addresses),
    )
    for is_private, addresses in address_checks:
        for address in addresses:
            if is_private(address, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
    return True


708
709
710
711
712
713
714
715
716
717
718
def resolve_hostname(hostname):
    """Resolve *hostname* and return ``(ipv4_addresses, ipv6_addresses)``.

    Both lists hold the address strings reported by ``socket.getaddrinfo``,
    in the order returned; duplicates are not removed.
    """
    ipv4_addresses = []
    ipv6_addresses = []

    # Each getaddrinfo entry is (family, type, proto, canonname, sockaddr);
    # the address string is the first element of sockaddr.
    for family, _type, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            ipv4_addresses.append(sockaddr[0])
        elif family == socket.AF_INET6:
            ipv6_addresses.append(sockaddr[0])

    return ipv4_addresses, ipv6_addresses


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
719
720
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    # Run a web search for form_data.query, load every result page and index
    # the loaded documents into the vector DB (overwriting any collection
    # that already has the same name).
    #
    # The search runs in its own try block, outside the general handler
    # below. When it was nested inside it, the WEB_SEARCH_ERROR
    # HTTPException was itself caught by the outer `except Exception`,
    # logged a second time and re-wrapped in ERROR_MESSAGES.DEFAULT, so the
    # specific message never reached the client.
    try:
        web_results = search_web(
            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
        )

    try:
        urls = [result.link for result in web_results]
        loader = get_web_loader(urls)
        data = loader.load()

        # An empty collection name means "derive one": a hash of the query,
        # cut to 63 characters.
        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


754
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in the vector DB.

    Chunking uses ``app.state.config.CHUNK_SIZE`` / ``CHUNK_OVERLAP``.
    Raises ``ValueError(ERROR_MESSAGES.EMPTY_CONTENT)`` when splitting
    produces no chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.split_documents(data)

    if len(chunks) == 0:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {chunks}")

    # NOTE(review): the value returned is the 2-tuple ``(stored, None)`` even
    # though the annotation says ``bool``. A non-empty tuple is always
    # truthy, so ``if result:`` checks in callers cannot observe a failed
    # store. Kept as-is to preserve the shape callers currently receive --
    # confirm and fix together with the call sites.
    stored = store_docs_in_vector_db(chunks, collection_name, overwrite)
    return stored, None
769
770
771


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw text into chunks and store them in the vector DB.

    ``metadata`` is attached to the documents created from ``text``. The
    return value is whatever ``store_docs_in_vector_db`` returns.
    """
    splitter_settings = {
        "chunk_size": app.state.config.CHUNK_SIZE,
        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        "add_start_index": True,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_settings)
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
783
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed ``docs`` and add them to a new Chroma collection.

    With ``overwrite`` set, an existing collection named ``collection_name``
    is deleted first. Returns ``True`` on success, and also when the failure
    is an exception whose class is named ``UniqueConstraintError``; any other
    exception is logged and yields ``False``.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [document.page_content for document in docs]
    metadatas = [document.metadata for document in docs]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if existing.name == collection_name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embed = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        # Newlines are flattened to spaces only for the text that gets
        # embedded; the stored documents keep their original text.
        flattened_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embed(flattened_texts)

        # One random UUID per chunk.
        chunk_ids = [str(uuid.uuid4()) for _ in texts]

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=chunk_ids,
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # The class name is compared as a string, so the exception type does
        # not have to be imported; that one error is reported as success.
        return e.__class__.__name__ == "UniqueConstraintError"


827
828
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a document loader for a file.

    The choice is made from the lower-cased text after the last ``.`` in
    ``filename`` and, for some formats, from ``file_content_type``; the
    checks run in the order written below and the first match wins.

    Returns ``(loader, known_type)``. ``known_type`` is ``False`` only when
    nothing matched and the file is handed to ``TextLoader`` as a fallback.
    """
    file_ext = filename.rsplit(".", 1)[-1].lower()

    # Extensions that are loaded as plain text.
    known_source_ext = (
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    )

    docx_mime = (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    )
    excel_mimes = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
    powerpoint_mimes = (
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )

    if file_ext == "pdf":
        pdf_loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
        return pdf_loader, True
    if file_ext == "csv":
        return CSVLoader(file_path), True
    if file_ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if file_ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if file_ext in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if file_ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    if file_content_type == docx_mime or file_ext in ("doc", "docx"):
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_mimes or file_ext in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True
    if file_content_type in powerpoint_mimes or file_ext in ("ppt", "pptx"):
        return UnstructuredPowerPointLoader(file_path), True

    # Both remaining cases use TextLoader; they differ only in whether the
    # type counts as recognised (known source extension or a "text/" MIME).
    looks_textual = file_ext in known_source_ext or (
        file_content_type and "text/" in file_content_type
    )
    return TextLoader(file_path, autodetect_encoding=True), bool(looks_textual)


919
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    # Save an uploaded file under UPLOAD_DIR, load it with the loader chosen
    # by get_loader and index its contents into the vector DB.
    # When no collection_name is supplied, one is derived from a hash of the
    # saved file, cut to 63 characters.

    log.info(f"file.content_type: {file.content_type}")
    try:
        # Keep only the final path component of the client-supplied name so
        # the upload cannot be written outside UPLOAD_DIR.
        unsanitized_filename = file.filename
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        # Only open the saved file again when a hash is actually needed, and
        # use a context manager so the handle is closed even if hashing
        # raises.
        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)

            if result:
                return {
                    "status": True,
                    "collection_name": collection_name,
                    "filename": filename,
                    "known_type": known_type,
                }
        except Exception as e:
            # NOTE(review): HTTPException is an Exception subclass, so this
            # 500 appears to be re-caught by the outer handler below and
            # reported as a 400 -- confirm which status is intended.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=e,
            )
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
974
975


976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
class TextRAGForm(BaseModel):
    # Request body for the POST /text endpoint (see store_text).
    # name: stored in the chunk metadata under the "name" key.
    name: str
    # content: the raw text that gets split and stored.
    content: str
    # collection_name: target collection; when None, store_text derives a
    # name from a hash of `content`.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    # Index a raw text snippet into the vector DB. The chunk metadata records
    # the snippet name and the id of the requesting user.

    collection_name = form_data.collection_name
    if collection_name is None:
        # Derive the name from a hash of the content, cut to 63 characters
        # like every other ingestion endpoint in this module does. Chroma
        # limits collection names to 63 characters, so an untruncated
        # 64-character SHA-256 hex digest would be rejected.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


1007
1008
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    # Walk DOCS_DIR, index every non-hidden file into the vector DB and
    # register a document record for files whose sanitized name is not known
    # yet. Failures are logged per file and never abort the scan; the
    # endpoint always returns True.
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            # Skip directories and hidden files (names starting with ".").
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # Context manager so the handle is released even if hashing
            # raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            # guess_type returns (type, encoding); only the type is needed.
            loader, _known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            try:
                result = store_data_in_vector_db(data, collection_name)

                if result:
                    sanitized_filename = sanitize_filename(filename)
                    doc = Documents.get_doc_by_name(sanitized_filename)

                    if doc is None:
                        # Folder-derived tags are stored as JSON in the
                        # document content; "{}" when there are none.
                        if len(tags):
                            content = json.dumps(
                                {"tags": [{"name": name} for name in tags]}
                            )
                        else:
                            content = "{}"

                        doc = Documents.insert_new_doc(
                            user.id,
                            DocumentForm(
                                **{
                                    "name": sanitized_filename,
                                    "title": filename,
                                    "collection_name": collection_name,
                                    "filename": filename,
                                    "content": content,
                                }
                            ),
                        )
            except Exception as e:
                # Best effort: a file that cannot be stored is logged and
                # skipped.
                log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
1068
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    # Admin-only: reset the Chroma client. Uploaded files are not touched
    # here (the /reset endpoint removes those as well).
    CHROMA_CLIENT.reset()
    return None
Timothy J. Baek's avatar
Timothy J. Baek committed
1071
1072
1073


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    # Admin-only: delete everything under UPLOAD_DIR, then reset the Chroma
    # client. Individual failures are logged and skipped; the endpoint always
    # returns True.
    folder = f"{UPLOAD_DIR}"
    for entry_name in os.listdir(folder):
        file_path = os.path.join(folder, entry_name)
        try:
            # Files and symlinks are unlinked; real directories are removed
            # recursively.
            unlinkable = os.path.isfile(file_path) or os.path.islink(file_path)
            if unlinkable:
                os.unlink(file_path)
                continue
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s" % (file_path, e))

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102


if ENV == "dev":
    # Debug routes, registered only when ENV is "dev".

    @app.get("/ef")
    async def get_embeddings():
        # Embed a fixed sample string with the app's embedding function.
        sample_embedding = app.state.EMBEDDING_FUNCTION("hello world")
        return {"result": sample_embedding}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        # Embed the text taken from the URL path.
        text_embedding = app.state.EMBEDDING_FUNCTION(text)
        return {"result": text_embedding}