main.py 28.2 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
14
from typing import List, Union, Sequence
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
    YoutubeLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
)
33
34
from langchain.text_splitter import RecursiveCharacterTextSplitter

35
36
37
38
39
import validators
import urllib.parse
import socket


40
41
from pydantic import BaseModel
from typing import Optional
42
import mimetypes
43
import uuid
44
45
import json

46
import sentence_transformers
47

48
49
50
51
52
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
53

54
from apps.rag.utils import (
55
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
56
57
58
59
60
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
61
    search_web,
62
)
Timothy J. Baek's avatar
Timothy J. Baek committed
63

64
65
66
67
68
69
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
70
from utils.utils import get_current_user, get_admin_user
71

72
from config import (
73
    SRC_LOG_LEVELS,
74
75
    UPLOAD_DIR,
    DOCS_DIR,
76
77
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
78
    RAG_EMBEDDING_ENGINE,
79
    RAG_EMBEDDING_MODEL,
80
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
81
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
82
    ENABLE_RAG_HYBRID_SEARCH,
83
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
84
    RAG_RERANKING_MODEL,
85
    PDF_EXTRACT_IMAGES,
86
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
87
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
88
89
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
90
    DEVICE_TYPE,
91
92
93
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
94
    RAG_TEMPLATE,
95
    ENABLE_RAG_LOCAL_WEB_FETCH,
96
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
97
)
98

99
100
from constants import ERROR_MESSAGES

101
102
103
# Module logger; verbosity for the RAG subsystem is taken from config.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Runtime-mutable settings seeded from the static config module. The admin
# endpoints in this module reassign these attributes at runtime, and the
# request handlers read them back from app.state.
for _state_key, _state_value in (
    ("TOP_K", RAG_TOP_K),
    ("RELEVANCE_THRESHOLD", RAG_RELEVANCE_THRESHOLD),
    ("ENABLE_RAG_HYBRID_SEARCH", ENABLE_RAG_HYBRID_SEARCH),
    (
        "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
        ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
    ),
    ("CHUNK_SIZE", CHUNK_SIZE),
    ("CHUNK_OVERLAP", CHUNK_OVERLAP),
    ("RAG_EMBEDDING_ENGINE", RAG_EMBEDDING_ENGINE),
    ("RAG_EMBEDDING_MODEL", RAG_EMBEDDING_MODEL),
    ("RAG_RERANKING_MODEL", RAG_RERANKING_MODEL),
    ("RAG_TEMPLATE", RAG_TEMPLATE),
    ("OPENAI_API_BASE_URL", RAG_OPENAI_API_BASE_URL),
    ("OPENAI_API_KEY", RAG_OPENAI_API_KEY),
    ("PDF_EXTRACT_IMAGES", PDF_EXTRACT_IMAGES),
):
    setattr(app.state, _state_key, _state_value)
# Don't leave the loop temporaries behind as module globals.
del _state_key, _state_value
127

Steven Kreitzer's avatar
Steven Kreitzer committed
128

129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """(Re)load the local SentenceTransformer embedding model into app.state.

    A model is only loaded when ``embedding_model`` is non-empty AND the
    configured embedding engine is the empty string (the in-process
    sentence-transformers path). In every other case the cached model
    (``app.state.sentence_transformer_ef``) is cleared to None.

    Args:
        embedding_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` as its update flag.
    """
    wants_local_model = bool(embedding_model) and app.state.RAG_EMBEDDING_ENGINE == ""
    if not wants_local_model:
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """(Re)load the CrossEncoder reranker into ``app.state.sentence_transformer_rf``.

    An empty ``reranking_model`` clears the cached reranker to None.

    Args:
        reranking_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` as its update flag.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the local embedding / reranking models at import time, honouring the
# configured auto-update flags.
update_embedding_model(
    embedding_model=app.state.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    reranking_model=app.state.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Embedding callable built from the current engine, model and credentials.
# update_embedding_config rebuilds it whenever these change.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.RAG_EMBEDDING_ENGINE,
    app.state.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.OPENAI_API_KEY,
    app.state.OPENAI_API_BASE_URL,
)

# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# arbitrary sites make credentialed cross-origin calls; confirm this is intended.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
188
class CollectionNameForm(BaseModel):
    """Request body naming the vector-store collection to write to."""

    # Defaults to "test" when omitted; an explicit null is also accepted.
    collection_name: Union[str, None] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
192
class UrlForm(CollectionNameForm):
    """Collection name plus the URL whose content should be indexed."""

    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
195

196
197
198
199
class SearchForm(CollectionNameForm):
    """Collection name plus the web-search query whose results are indexed."""

    query: str


Timothy J. Baek's avatar
Timothy J. Baek committed
200
201
@app.get("/")
async def get_status():
    """Report liveness together with the current chunking, template and models."""
    state = app.state
    return dict(
        status=True,
        chunk_size=state.CHUNK_SIZE,
        chunk_overlap=state.CHUNK_OVERLAP,
        template=state.RAG_TEMPLATE,
        embedding_engine=state.RAG_EMBEDDING_ENGINE,
        embedding_model=state.RAG_EMBEDDING_MODEL,
        reranking_model=state.RAG_RERANKING_MODEL,
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
213
214
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the active embedding engine/model and OpenAI-compatible settings.

    Guarded by ``get_admin_user``. NOTE: the stored API key is returned as-is.
    """
    openai_settings = {
        "url": app.state.OPENAI_API_BASE_URL,
        "key": app.state.OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "openai_config": openai_settings,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
226
227
228
229
230
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the configured reranking model name (guarded by get_admin_user).

    The misspelled function name is kept so existing references keep working.
    """
    current_model = app.state.RAG_RERANKING_MODEL
    return {"status": True, "reranking_model": current_model}


231
232
233
234
235
class OpenAIConfigForm(BaseModel):
    """Base URL and API key for an OpenAI-compatible embedding endpoint."""

    url: str  # base URL of the endpoint
    key: str  # API key sent to that endpoint


236
class EmbeddingModelUpdateForm(BaseModel):
    """Admin request to switch the embedding engine and model."""

    # Connection settings; update_embedding_config only applies them when the
    # engine is "ollama" or "openai".
    openai_config: Union[OpenAIConfigForm, None] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
242
243
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model (guarded by get_admin_user).

    For the "ollama" and "openai" engines, a supplied ``openai_config`` also
    replaces the stored base URL and API key. The local model is reloaded
    (with the update flag set) and the embedding function is rebuilt.

    If anything fails, every setting touched here is restored to its previous
    value. This stops GET /embedding from advertising a model that was never
    loaded.

    Raises:
        HTTPException(500): the model could not be loaded or the embedding
            function could not be built.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )

    # Snapshot everything this handler may overwrite, for rollback on failure.
    previous_state = {
        "RAG_EMBEDDING_ENGINE": app.state.RAG_EMBEDDING_ENGINE,
        "RAG_EMBEDDING_MODEL": app.state.RAG_EMBEDDING_MODEL,
        "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
        "sentence_transformer_ef": app.state.sentence_transformer_ef,
        "EMBEDDING_FUNCTION": app.state.EMBEDDING_FUNCTION,
    }

    try:
        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if (
            app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
            and form_data.openai_config is not None
        ):
            app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
            app.state.OPENAI_API_KEY = form_data.openai_config.key

        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        return {
            "status": True,
            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.OPENAI_API_BASE_URL,
                "key": app.state.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        # Restore the pre-request configuration so state stays consistent.
        for attr_name, attr_value in previous_state.items():
            setattr(app.state, attr_name, attr_value)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
283
284


Steven Kreitzer's avatar
Steven Kreitzer committed
285
286
class RerankingModelUpdateForm(BaseModel):
    """Admin request to switch the reranking model (empty string disables it)."""

    reranking_model: str
287

Steven Kreitzer's avatar
Steven Kreitzer committed
288
289
290
291
292
293
294
295
296
297

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model (guarded by get_admin_user).

    The model is reloaded with the update flag set. If loading fails, both the
    stored model name and the previously loaded reranker are restored, so the
    reported setting always matches the loaded model.

    Raises:
        HTTPException(500): the new reranking model could not be loaded.
    """
    log.info(
        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )

    # Snapshot for rollback on failure.
    previous_model_name = app.state.RAG_RERANKING_MODEL
    previous_reranker = app.state.sentence_transformer_rf

    try:
        app.state.RAG_RERANKING_MODEL = form_data.reranking_model

        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        app.state.RAG_RERANKING_MODEL = previous_model_name
        app.state.sentence_transformer_rf = previous_reranker
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e


Timothy J. Baek's avatar
Timothy J. Baek committed
313
314
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the document-ingestion settings (guarded by get_admin_user)."""
    chunk_settings = {
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
    }
    ssl_verification = app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
        "web_loader_ssl_verification": ssl_verification,
    }


class ChunkParamUpdateForm(BaseModel):
    """Text-splitter parameters; both are always sent together."""

    chunk_size: int
    chunk_overlap: int


Timothy J. Baek's avatar
Timothy J. Baek committed
331
class ConfigUpdateForm(BaseModel):
    """Partial update of ingestion settings; a None field leaves it unchanged."""

    pdf_extract_images: Union[bool, None] = None
    chunk: Union[ChunkParamUpdateForm, None] = None
    web_loader_ssl_verification: Union[bool, None] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
335
336
337
338


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Partially update ingestion settings (guarded by get_admin_user).

    Fields sent as None keep their current values. Chunk parameters are
    validated before any setting changes, so a rejected request leaves
    everything untouched.

    Returns:
        The full set of ingestion settings after the update.

    Raises:
        HTTPException(400): chunk_size <= 0, chunk_overlap < 0, or
            chunk_overlap > chunk_size.
    """
    chunk = form_data.chunk

    # RecursiveCharacterTextSplitter rejects chunk_overlap > chunk_size when it
    # is constructed. Without this check, such a value would be stored and make
    # every later ingestion fail.
    if chunk is not None and not (
        chunk.chunk_size > 0 and 0 <= chunk.chunk_overlap <= chunk.chunk_size
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "Invalid chunk settings: chunk_size must be positive and "
                "chunk_overlap must be between 0 and chunk_size."
            ),
        )

    if form_data.pdf_extract_images is not None:
        app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images

    if chunk is not None:
        app.state.CHUNK_SIZE = chunk.chunk_size
        app.state.CHUNK_OVERLAP = chunk.chunk_overlap

    if form_data.web_loader_ssl_verification is not None:
        app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
            form_data.web_loader_ssl_verification
        )

    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.CHUNK_SIZE,
            "chunk_overlap": app.state.CHUNK_OVERLAP,
        },
        "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
    }
370
371


Timothy J. Baek's avatar
Timothy J. Baek committed
372
373
374
375
376
377
378
379
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the active RAG prompt template to any authenticated user."""
    template = app.state.RAG_TEMPLATE
    return dict(status=True, template=template)


380
381
382
383
384
385
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return retrieval settings: template, top-k, relevance threshold, hybrid flag."""
    state = app.state
    return dict(
        status=True,
        template=state.RAG_TEMPLATE,
        k=state.TOP_K,
        r=state.RELEVANCE_THRESHOLD,
        hybrid=state.ENABLE_RAG_HYBRID_SEARCH,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
389
390


391
392
class QuerySettingsForm(BaseModel):
    """Retrieval settings; missing or falsy values reset to defaults on update."""

    k: Union[int, None] = None  # top-k results
    r: Union[float, None] = None  # relevance threshold
    template: Union[str, None] = None  # RAG prompt template
    hybrid: Union[bool, None] = None  # hybrid search toggle
396
397
398
399
400
401
402
403


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the retrieval settings (guarded by get_admin_user).

    This is a full replace, not a patch. Any field that is missing or falsy
    resets to its default: template -> RAG_TEMPLATE, k -> 4, r -> 0.0,
    hybrid -> False.
    """
    app.state.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    app.state.TOP_K = form_data.k or 4
    app.state.RELEVANCE_THRESHOLD = form_data.r or 0.0
    app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return dict(
        status=True,
        template=app.state.RAG_TEMPLATE,
        k=app.state.TOP_K,
        r=app.state.RELEVANCE_THRESHOLD,
        hybrid=app.state.ENABLE_RAG_HYBRID_SEARCH,
    )
413
414


415
class QueryDocForm(BaseModel):
    """Retrieval query against a single collection."""

    collection_name: str
    query: str
    # Per-request overrides; falsy values fall back to the server settings.
    k: Union[int, None] = None
    r: Union[float, None] = None
    # NOTE(review): accepted but not read by query_doc_handler, which uses the
    # server-wide hybrid-search flag instead.
    hybrid: Union[bool, None] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
421
422


423
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Run a retrieval query against one collection.

    Hybrid search (with the loaded reranker and relevance threshold) is used
    when the server-wide ENABLE_RAG_HYBRID_SEARCH flag is set. The request's
    own ``hybrid`` field is not consulted. A falsy ``k``/``r`` in the request
    falls back to the server settings.

    Raises:
        HTTPException(400): any error from the underlying query.
    """
    try:
        top_k = form_data.k or app.state.TOP_K

        if not app.state.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        return query_doc_with_hybrid_search(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or app.state.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
451
452


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
453
454
455
class QueryCollectionsForm(BaseModel):
    """Retrieval query spanning several collections."""

    collection_names: List[str]
    query: str
    # Per-request overrides; falsy values fall back to the server settings.
    k: Union[int, None] = None
    r: Union[float, None] = None
    # NOTE(review): accepted but not read by query_collection_handler.
    hybrid: Union[bool, None] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
459
460


461
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Run a retrieval query across several collections.

    Mirrors query_doc_handler: the server-wide hybrid flag selects hybrid
    search (reranker plus relevance threshold). The request's ``hybrid`` field
    is ignored, and a falsy ``k``/``r`` falls back to the server settings.

    Raises:
        HTTPException(400): any error from the underlying query.
    """
    try:
        top_k = form_data.k or app.state.TOP_K

        if not app.state.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        return query_collection_with_hybrid_search(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or app.state.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
490
491


Timothy J. Baek's avatar
Timothy J. Baek committed
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video's content via YoutubeLoader and index it.

    An empty or null collection name is replaced by the first 63 hex chars of
    the URL's SHA-256 (presumably to fit the vector store's name-length
    limit). The target collection is overwritten.

    Raises:
        HTTPException(400): loading or storing failed.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(form_data.url, add_video_info=False)
        data = loader.load()

        collection_name = form_data.collection_name
        # collection_name is Optional: treat an explicit null like "" instead
        # of passing None through to the vector store.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e


516
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and index its text into a collection.

    The URL is validated by get_web_loader. SSL verification follows the
    admin-configurable app.state flag. An empty or null collection name is
    replaced by a truncated SHA-256 of the URL, and the collection is
    overwritten.

    Raises:
        HTTPException(400): validation, fetching or storing failed.
    """
    try:
        loader = get_web_loader(
            form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # collection_name is Optional: treat an explicit null like "" instead
        # of passing None through to the vector store.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e

542

543
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a WebBaseLoader for one URL or a list of URLs after validating them.

    Raises:
        ValueError: ERROR_MESSAGES.INVALID_URL when validate_url rejects the
            input (validate_url may also raise this itself).
    """
    if not validate_url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    # NOTE(review): a "concurrent requests" setting feeds the loader's
    # requests_per_second rate; confirm that mapping is intended.
    loader_options = {
        "verify_ssl": verify_ssl,
        "requests_per_second": RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    }
    return WebBaseLoader(url, **loader_options)
552
553


554
555
556
557
def _is_disallowed_ip(address: str) -> bool:
    """Return True if *address* must not be fetched while local fetching is off.

    Only globally routable unicast addresses are allowed. Rejected: private,
    loopback, link-local (including 169.254.169.254), shared/CGNAT,
    multicast, reserved and unspecified ranges. Unparseable input is rejected
    as well (fail closed).
    """
    import ipaddress

    # getaddrinfo can append an IPv6 zone id ("fe80::1%eth0"); strip it so the
    # address parses on every Python version.
    try:
        ip = ipaddress.ip_address(address.split("%", 1)[0])
    except ValueError:
        return True

    # An IPv4-mapped IPv6 address (::ffff:a.b.c.d) reaches the IPv4 host, so
    # classify it by the embedded IPv4 address.
    if ip.version == 6 and ip.ipv4_mapped is not None:
        ip = ip.ipv4_mapped

    return (
        not ip.is_global
        or ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_reserved
        or ip.is_unspecified
    )


def validate_url(url: Union[str, Sequence[str]]):
    """Validate a URL, or every URL in a sequence, before it is fetched.

    When ENABLE_RAG_LOCAL_WEB_FETCH is off, the hostname is resolved. The URL
    is rejected if any resolved IPv4 or IPv6 address is not publicly routable.
    The stdlib ``ipaddress`` module does this check; the previous
    ``validators.ipv6(..., private=True)`` call did not reliably flag private
    IPv6 addresses.

    Returns:
        True if the URL (or all URLs) are acceptable; False when ``url`` is
        neither a string nor a sequence.

    Raises:
        ValueError: ERROR_MESSAGES.INVALID_URL for a malformed URL or a
            disallowed resolved address.
        socket.gaierror: the hostname could not be resolved.
    """
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            parsed_url = urllib.parse.urlparse(url)
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            # Still open to DNS rebinding and redirects, since WebBaseLoader
            # resolves and follows the URL again on its own.
            for ip in [*ipv4_addresses, *ipv6_addresses]:
                if _is_disallowed_ip(ip):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False


578
579
580
581
582
583
584
585
586
587
588
def resolve_hostname(hostname):
    """Resolve *hostname* and split the resulting addresses by family.

    Returns:
        ``(ipv4_addresses, ipv6_addresses)``: lists of address strings in the
        order ``socket.getaddrinfo`` produced them. Duplicates are kept, since
        getaddrinfo may return one entry per socket type.

    Raises:
        socket.gaierror: resolution failed.
    """
    ipv4_addresses = []
    ipv6_addresses = []
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            ipv4_addresses.append(sockaddr[0])
        elif family == socket.AF_INET6:
            ipv6_addresses.append(sockaddr[0])
    return ipv4_addresses, ipv6_addresses


589
590
591
@app.post("/websearch")
def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
    """Run a web search, fetch the result pages and index them.

    Result pages are fetched with the same admin-configurable SSL-verification
    setting as the /web endpoint. An empty or null collection name is
    replaced by a truncated SHA-256 of the query, and the collection is
    overwritten.

    Raises:
        HTTPException(400): detail ERROR_MESSAGES.WEB_SEARCH_ERROR if the
            search itself failed. Otherwise the detail comes from
            ERROR_MESSAGES.DEFAULT for fetch/store failures.
    """
    try:
        try:
            web_results = search_web(form_data.query)
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
            ) from e

        urls = [result.link for result in web_results]
        # Honour the SSL-verification toggle exactly like store_web does.
        loader = get_web_loader(
            urls, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
        )
        data = loader.aload()

        collection_name = form_data.collection_name
        # collection_name is Optional: treat an explicit null like "".
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except HTTPException:
        # Already carries a specific status and detail; don't rewrap it.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e


622
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Chunking uses the current app.state chunk size/overlap, and each chunk
    records its start index.

    Returns:
        The bool from store_docs_in_vector_db. Previously this returned the
        tuple ``(result, None)``, which is always truthy, so ``if result:``
        callers treated failed stores as successes.

    Raises:
        ValueError: ERROR_MESSAGES.EMPTY_CONTENT when splitting yields no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    return store_docs_in_vector_db(docs, collection_name, overwrite)
637
638
639


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk a raw string and store the chunks in a collection.

    Every chunk carries a copy of ``metadata`` plus its start index. Chunk
    size and overlap come from the current app.state settings.

    Returns:
        Whatever store_docs_in_vector_db returns.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(
        docs=chunks, collection_name=collection_name, overwrite=overwrite
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
651
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed documents and add them to a Chroma collection.

    Args:
        docs: Documents exposing ``page_content`` and ``metadata``.
        collection_name: Target collection name.
        overwrite: Delete an existing collection of that name first. Without
            it, creating an already-existing collection fails.

    Returns:
        True on success. Also True when the failure is a
        ``UniqueConstraintError`` (presumably "collection already exists"; in
        that case nothing new is stored). False on any other error, which is
        logged rather than raised.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    page_texts = [d.page_content for d in docs]
    page_metadatas = [d.metadata for d in docs]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if existing.name == collection_name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        target = CHROMA_CLIENT.create_collection(name=collection_name)

        embed = get_embedding_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        # Newlines are flattened only for embedding; stored documents keep
        # their original text.
        vectors = embed([t.replace("\n", " ") for t in page_texts])

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid1()) for _ in page_texts],
            metadatas=page_metadatas,
            embeddings=vectors,
            documents=page_texts,
        )
        for batch in batches:
            target.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # Matched by class name rather than by importing the Chroma type.
        return type(e).__name__ == "UniqueConstraintError"


695
696
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a document loader for an uploaded file.

    Args:
        filename: Original file name. The lower-cased text after its last "."
            (or the whole name when there is no ".") selects the loader.
        file_content_type: MIME type reported for the upload. It may be empty
            or None and is only consulted for some formats.
        file_path: Path of the saved file, handed to the loader.

    Returns:
        A ``(loader, known_type)`` tuple. ``known_type`` is False only when no
        rule matched and the file is read as plain text as a last resort.
    """
    ext = filename.split(".")[-1].lower()

    # Extensions (and extension-less names such as "dockerfile") read as text.
    plain_text_exts = frozenset(
        (
            "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts",
            "css", "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini",
            "pl", "pm", "r", "dart", "dockerfile", "env", "php", "hs",
            "hsc", "lua", "nginxconf", "conf", "m", "mm", "plsql", "perl",
            "rb", "rs", "db2", "scala", "bash", "swift", "vue", "svelte",
        )
    )
    docx_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    excel_mimes = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )

    # Rules are checked in order and the first match wins. For example, ".md"
    # uses the Markdown loader even when the MIME type is "text/*".
    if ext == "pdf":
        return PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES), True
    if ext == "csv":
        return CSVLoader(file_path), True
    if ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if ext in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    # NOTE(review): legacy ".doc" also goes to Docx2txtLoader, which
    # presumably only handles OOXML .docx; confirm.
    if file_content_type == docx_mime or ext in ("doc", "docx"):
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_mimes or ext in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True
    if ext in plain_text_exts or (file_content_type and "text/" in file_content_type):
        return TextLoader(file_path, autodetect_encoding=True), True

    # Nothing matched: still try plain text, but report the type as unknown.
    return TextLoader(file_path, autodetect_encoding=True), False


780
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file and index its contents in the vector DB.

    The upload is written to ``UPLOAD_DIR`` under its base name, parsed with
    the loader chosen by ``get_loader``, and the resulting documents are
    passed to ``store_data_in_vector_db``.

    Args:
        collection_name: Target collection. When omitted, the first 63
            characters of ``calculate_sha256`` over the saved file are used.
        file: The uploaded document.
        user: Caller resolved by the ``get_current_user`` dependency; not
            otherwise used in the body.

    Returns:
        ``{"status": True, "collection_name": ..., "filename": ...,
        "known_type": ...}`` where ``known_type`` comes from ``get_loader``.

    Raises:
        HTTPException: 400 if saving, loading or storing raises (with
            ``ERROR_MESSAGES.PANDOC_NOT_INSTALLED`` when the error text says
            pandoc is missing); 500 if ``store_data_in_vector_db`` returns a
            falsy result without raising.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        # basename() drops any directory components of the client-supplied
        # name so the write cannot land outside UPLOAD_DIR.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        # Stream the upload to disk rather than reading it fully into memory.
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        result = store_data_in_vector_db(data, collection_name)
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            ) from e
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e

    # Previously a falsy result fell through and the endpoint returned None
    # (HTTP 200 with a null body); report it as a server error instead.
    if not result:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "filename": filename,
        "known_type": known_type,
    }
835
836


837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
class TextRAGForm(BaseModel):
    # Request body accepted by the POST /text handler (store_text).
    name: str  # recorded as the "name" metadata of the stored text
    content: str  # the raw text to embed
    # Optional target collection; when absent the handler derives one by
    # hashing `content`.
    collection_name: Union[str, None] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Embed a raw text snippet in the vector DB.

    Args:
        form_data: ``name`` and ``content`` of the text plus an optional
            ``collection_name``. Without one, the collection name is
            ``calculate_sha256_string(content)`` capped at 63 characters.
        user: Caller resolved by ``get_current_user``; its ``id`` is stored
            as ``created_by`` metadata.

    Returns:
        ``{"status": True, "collection_name": <collection used>}``.

    Raises:
        HTTPException: 500 if ``store_text_in_vector_db`` returns a falsy
            result.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Cap at 63 characters exactly like the /doc and /scan handlers do.
        # If calculate_sha256_string returns a full SHA-256 hex digest it is
        # 64 characters, one over that cap (presumably the vector store's
        # collection-name length limit).
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}

    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=ERROR_MESSAGES.DEFAULT(),
    )


868
869
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every file under ``DOCS_DIR`` into the vector DB.

    Each regular file whose own name does not start with "." is hashed (the
    first 63 characters of ``calculate_sha256`` become its collection name),
    loaded via ``get_loader`` and stored with ``store_data_in_vector_db``.
    When storing succeeds and no ``Documents`` row exists for the sanitized
    filename, a new row is inserted owned by ``user``. Its ``content`` is a
    JSON object listing the folder tags from
    ``extract_folders_after_data_docs``, or ``"{}"`` when there are none.

    Errors for an individual file are logged and that file is skipped.

    Returns:
        Always ``True``.
    """
    # rglob already recurses; a plain "*" matches every path beneath DOCS_DIR.
    for path in Path(DOCS_DIR).rglob("*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, _ = get_loader(filename, file_content_type[0], str(path))
            data = loader.load()

            result = store_data_in_vector_db(data, collection_name)
            if not result:
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is not None:
                continue

            content = (
                json.dumps({"tags": [{"name": name} for name in tags]})
                if tags
                else "{}"
            )
            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    **{
                        "name": sanitized_filename,
                        "title": filename,
                        "collection_name": collection_name,
                        "filename": filename,
                        "content": content,
                    }
                ),
            )
        except Exception as e:
            # Best-effort scan: one bad file must not abort the whole walk.
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
929
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the vector database by calling ``CHROMA_CLIENT.reset()``.

    Guarded by the ``get_admin_user`` dependency. Any error raised by the
    client propagates to the caller; nothing is returned on success.
    """
    chroma = CHROMA_CLIENT
    chroma.reset()
    return None
Timothy J. Baek's avatar
Timothy J. Baek committed
932
933
934


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Empty ``UPLOAD_DIR`` and then reset the Chroma client.

    Files and symlinks are unlinked and sub-directories removed recursively.
    A failure on any single entry is logged and the loop moves on. A failure
    while resetting ``CHROMA_CLIENT`` is likewise logged, not raised.

    Returns:
        Always ``True``.
    """
    upload_root = f"{UPLOAD_DIR}"
    for entry_name in os.listdir(upload_root):
        target = os.path.join(upload_root, entry_name)
        try:
            # Symlinks are unlinked rather than followed, even if they point
            # at a directory.
            removable_as_file = os.path.isfile(target) or os.path.islink(target)
            if removable_as_file:
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            log.error("Failed to delete %s. Reason: %s", target, err)

    try:
        CHROMA_CLIENT.reset()
    except Exception as err:
        log.exception(err)

    return True