main.py 27.5 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
    YoutubeLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
)
33
34
from langchain.text_splitter import RecursiveCharacterTextSplitter

35
36
37
38
39
import validators
import urllib.parse
import socket


40
41
from pydantic import BaseModel
from typing import Optional
42
import mimetypes
43
import uuid
44
45
import json

46
import sentence_transformers
47

48
49
50
51
52
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
53

54
from apps.rag.utils import (
55
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
56
57
58
59
60
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
61
)
Timothy J. Baek's avatar
Timothy J. Baek committed
62

63
64
65
66
67
68
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
69
from utils.utils import get_current_user, get_admin_user
70

71
from config import (
72
    SRC_LOG_LEVELS,
73
74
    UPLOAD_DIR,
    DOCS_DIR,
75
76
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
77
    RAG_EMBEDDING_ENGINE,
78
    RAG_EMBEDDING_MODEL,
79
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
80
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
81
    ENABLE_RAG_HYBRID_SEARCH,
82
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
83
    RAG_RERANKING_MODEL,
84
    PDF_EXTRACT_IMAGES,
85
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
86
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
87
88
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
89
    DEVICE_TYPE,
90
91
92
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
93
    RAG_TEMPLATE,
94
    ENABLE_RAG_LOCAL_WEB_FETCH,
95
)
96

97
98
from constants import ERROR_MESSAGES

99
100
101
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# --- Mutable runtime configuration -----------------------------------------
# Every value below is seeded from `config` at import time; the admin
# endpoints in this module overwrite them at runtime.

# Retrieval parameters used by the /query endpoints.
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH

# Web page loader.
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Text splitting applied before documents are embedded.
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking model selection and the prompt template.
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

# Endpoint/credentials handed to get_embedding_function (used by the
# "ollama"/"openai" engines).
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# Document loader options.
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.YOUTUBE_LOADER_LANGUAGE = ["en"]
app.state.YOUTUBE_LOADER_TRANSLATION = None


131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """(Re)load the local SentenceTransformer embedding model into app state.

    A local model is only loaded when ``app.state.RAG_EMBEDDING_ENGINE`` is the
    empty string and *embedding_model* is truthy; otherwise
    ``app.state.sentence_transformer_ef`` is cleared to ``None``.

    Args:
        embedding_model: Model name, resolved to a path by ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` (presumably asks it to
            refresh the cached model — confirm in apps.rag.utils).
    """
    if not embedding_model or app.state.RAG_EMBEDDING_ENGINE != "":
        # Remote engines (or no model configured) need no local encoder.
        app.state.sentence_transformer_ef = None
        return

    model_location = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_location,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """(Re)load the CrossEncoder reranker into ``app.state.sentence_transformer_rf``.

    A falsy *reranking_model* clears the reranker to ``None``.

    Args:
        reranking_model: Model name, resolved to a path by ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` (presumably asks it to
            refresh the cached model — confirm in apps.rag.utils).
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_location = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_location,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the configured models once at import time; the /embedding/update and
# /reranking/update endpoints reload them later with update_model=True.
update_embedding_model(
    app.state.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    app.state.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Embedding callable built from the current engine/model/credentials; it is
# rebuilt by /embedding/update whenever any of those change.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.RAG_EMBEDDING_ENGINE,
    app.state.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.OPENAI_API_KEY,
    app.state.OPENAI_API_BASE_URL,
)

# NOTE(review): a wildcard origin together with allow_credentials=True may let
# arbitrary sites make credentialed requests — confirm this is intended.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
190
class CollectionNameForm(BaseModel):
    """Request body carrying an optional target collection name."""

    # Falls back to "test" when the client omits the field entirely.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
194
class UrlForm(CollectionNameForm):
    """Request body for URL ingestion: the URL plus the inherited collection name."""

    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
197

Timothy J. Baek's avatar
Timothy J. Baek committed
198
199
@app.get("/")
async def get_status():
    """Status probe (no auth dependency) reporting the current RAG settings.

    The OpenAI-compatible URL/key are deliberately not part of this payload.
    """
    return dict(
        status=True,
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        template=app.state.RAG_TEMPLATE,
        embedding_engine=app.state.RAG_EMBEDDING_ENGINE,
        embedding_model=app.state.RAG_EMBEDDING_MODEL,
        reranking_model=app.state.RAG_RERANKING_MODEL,
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
211
212
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and OpenAI-compatible settings.

    Guarded by ``get_admin_user``. Note that the response contains the stored
    API key verbatim.
    """
    openai_settings = {
        "url": app.state.OPENAI_API_BASE_URL,
        "key": app.state.OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "openai_config": openai_settings,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
224
225
226
227
228
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the currently configured reranking model (guarded by get_admin_user).

    The misspelt function name is kept so existing references keep working;
    clients address this endpoint by its route path.
    """
    current_model = app.state.RAG_RERANKING_MODEL
    return {"status": True, "reranking_model": current_model}


229
230
231
232
233
class OpenAIConfigForm(BaseModel):
    """Base URL and API key for an OpenAI-compatible embedding endpoint."""

    url: str  # stored as app.state.OPENAI_API_BASE_URL
    key: str  # stored as app.state.OPENAI_API_KEY


234
class EmbeddingModelUpdateForm(BaseModel):
    """Request body for ``/embedding/update``."""

    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    # "" selects the local SentenceTransformer path in update_embedding_model.
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
240
241
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function.

    For the "ollama" and "openai" engines, ``form_data.openai_config`` (when
    provided) replaces the stored base URL and API key. If loading the model
    or building the embedding function fails, every piece of ``app.state``
    touched here is restored, so the reported engine/model keep matching the
    embedding function actually in use.

    Raises:
        HTTPException: 500 if the update fails; state is left unchanged.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )

    # Snapshot everything this handler may change so a failure can be undone.
    previous_state = {
        name: getattr(app.state, name)
        for name in (
            "RAG_EMBEDDING_ENGINE",
            "RAG_EMBEDDING_MODEL",
            "OPENAI_API_BASE_URL",
            "OPENAI_API_KEY",
            "sentence_transformer_ef",
            "EMBEDDING_FUNCTION",
        )
    }

    try:
        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if (
            app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
            and form_data.openai_config is not None
        ):
            app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
            app.state.OPENAI_API_KEY = form_data.openai_config.key

        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        for name, value in previous_state.items():
            setattr(app.state, name, value)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    return {
        "status": True,
        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "openai_config": {
            "url": app.state.OPENAI_API_BASE_URL,
            "key": app.state.OPENAI_API_KEY,
        },
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
281
282


Steven Kreitzer's avatar
Steven Kreitzer committed
283
284
class RerankingModelUpdateForm(BaseModel):
    """Request body for ``/reranking/update``; an empty name disables reranking."""

    reranking_model: str
285

Steven Kreitzer's avatar
Steven Kreitzer committed
286
287
288
289
290
291
292
293
294
295

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model used by the hybrid-search query paths.

    If the new model fails to load, the previous model name and the previously
    loaded reranker are restored, so the reported configuration stays accurate.

    Raises:
        HTTPException: 500 if the model cannot be loaded; state is left unchanged.
    """
    log.info(
        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )

    previous_model = app.state.RAG_RERANKING_MODEL
    previous_reranker = app.state.sentence_transformer_rf

    try:
        app.state.RAG_RERANKING_MODEL = form_data.reranking_model
        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        # Roll back so the name we report matches the reranker in use.
        app.state.RAG_RERANKING_MODEL = previous_model
        app.state.sentence_transformer_rf = previous_reranker
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    return {
        "status": True,
        "reranking_model": app.state.RAG_RERANKING_MODEL,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
311
312
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the document-ingestion settings (guarded by get_admin_user)."""
    chunk_settings = dict(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
    )
    youtube_settings = dict(
        language=app.state.YOUTUBE_LOADER_LANGUAGE,
        translation=app.state.YOUTUBE_LOADER_TRANSLATION,
    )
    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
        "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": youtube_settings,
    }


class ChunkParamUpdateForm(BaseModel):
    """Text-splitter chunking parameters (stored as app.state.CHUNK_SIZE/OVERLAP)."""

    chunk_size: int
    chunk_overlap: int


333
334
335
336
337
class YoutubeLoaderConfig(BaseModel):
    """Options forwarded to ``YoutubeLoader.from_youtube_url``."""

    language: List[str]
    translation: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
338
class ConfigUpdateForm(BaseModel):
    """Partial update for ``/config/update``; ``None`` keeps the current value."""

    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    web_loader_ssl_verification: Optional[bool] = None
    youtube: Optional[YoutubeLoaderConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
343
344
345
346


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply a partial update to the document-ingestion settings.

    Each top-level field left as ``None`` keeps its current value. Chunk
    parameters are validated before anything is modified.

    Returns:
        The full settings payload after the update (same shape as GET /config).

    Raises:
        HTTPException: 400 for invalid chunk parameters; no setting is changed.
    """
    if form_data.chunk is not None:
        size = form_data.chunk.chunk_size
        overlap = form_data.chunk.chunk_overlap
        # The text splitter raises ValueError when overlap exceeds size, so
        # storing such values would break every later ingestion; reject them
        # (and non-positive sizes / negative overlaps) before touching state.
        if size <= 0 or overlap < 0 or overlap > size:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    "Invalid chunk parameters: chunk_size must be positive and "
                    "0 <= chunk_overlap <= chunk_size "
                    f"(got chunk_size={size}, chunk_overlap={overlap})."
                ),
            )

    if form_data.pdf_extract_images is not None:
        app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images

    if form_data.chunk is not None:
        app.state.CHUNK_SIZE = form_data.chunk.chunk_size
        app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.web_loader_ssl_verification is not None:
        app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
            form_data.web_loader_ssl_verification
        )

    if form_data.youtube is not None:
        app.state.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.CHUNK_SIZE,
            "chunk_overlap": app.state.CHUNK_OVERLAP,
        },
        "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": {
            "language": app.state.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
    }
394
395


Timothy J. Baek's avatar
Timothy J. Baek committed
396
397
398
399
400
401
402
403
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template (any user accepted by get_current_user)."""
    return dict(status=True, template=app.state.RAG_TEMPLATE)


404
405
406
407
408
409
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the retrieval settings used by the /query endpoints."""
    return dict(
        status=True,
        template=app.state.RAG_TEMPLATE,
        k=app.state.TOP_K,
        r=app.state.RELEVANCE_THRESHOLD,
        hybrid=app.state.ENABLE_RAG_HYBRID_SEARCH,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
413
414


415
416
class QuerySettingsForm(BaseModel):
    """Request body for ``/query/settings/update``."""

    k: Optional[int] = None  # becomes app.state.TOP_K
    r: Optional[float] = None  # becomes app.state.RELEVANCE_THRESHOLD
    template: Optional[str] = None  # becomes app.state.RAG_TEMPLATE
    hybrid: Optional[bool] = None  # becomes app.state.ENABLE_RAG_HYBRID_SEARCH
420
421
422
423
424
425
426
427


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the retrieval settings.

    Unlike ``/config/update`` this is not a partial update: any field left
    unset is reset to the same configured default the app starts with
    (``RAG_TEMPLATE``, ``RAG_TOP_K``, ``RAG_RELEVANCE_THRESHOLD``,
    ``ENABLE_RAG_HYBRID_SEARCH``), instead of to hard-coded constants that
    could disagree with the deployment's configuration.

    Returns:
        The settings now in effect.
    """
    # An empty template or k == 0 is meaningless, so both count as "unset".
    # r == 0.0 and hybrid == False are legitimate choices, so only None
    # falls back for them.
    app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
    app.state.TOP_K = form_data.k if form_data.k else RAG_TOP_K
    app.state.RELEVANCE_THRESHOLD = (
        form_data.r if form_data.r is not None else RAG_RELEVANCE_THRESHOLD
    )
    app.state.ENABLE_RAG_HYBRID_SEARCH = (
        form_data.hybrid
        if form_data.hybrid is not None
        else ENABLE_RAG_HYBRID_SEARCH
    )

    return {
        "status": True,
        "template": app.state.RAG_TEMPLATE,
        "k": app.state.TOP_K,
        "r": app.state.RELEVANCE_THRESHOLD,
        "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
    }
437
438


439
class QueryDocForm(BaseModel):
    """Request body for ``/query/doc``: query one collection."""

    collection_name: str
    query: str
    k: Optional[int] = None  # falls back to app.state.TOP_K
    r: Optional[float] = None  # falls back to app.state.RELEVANCE_THRESHOLD
    # NOTE(review): accepted but not read by query_doc_handler, which only
    # consults app.state.ENABLE_RAG_HYBRID_SEARCH.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
445
446


447
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Retrieve the chunks of one collection most relevant to ``form_data.query``.

    Uses hybrid search with the reranker when
    ``app.state.ENABLE_RAG_HYBRID_SEARCH`` is on, otherwise plain vector
    search. ``k`` and ``r`` fall back to the global settings when omitted.
    An explicit ``r=0.0`` is honoured rather than mistaken for "not provided".

    Raises:
        HTTPException: 400 wrapping any retrieval error.
    """
    k = form_data.k if form_data.k else app.state.TOP_K
    r = form_data.r if form_data.r is not None else app.state.RELEVANCE_THRESHOLD

    # NOTE(review): form_data.hybrid is not consulted; only the global flag
    # chooses the search mode.
    try:
        if app.state.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=k,
                reranking_function=app.state.sentence_transformer_rf,
                r=r,
            )
        return query_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
475
476


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
477
478
479
class QueryCollectionsForm(BaseModel):
    """Request body for ``/query/collection``: query several collections at once."""

    collection_names: List[str]
    query: str
    k: Optional[int] = None  # falls back to app.state.TOP_K
    r: Optional[float] = None  # falls back to app.state.RELEVANCE_THRESHOLD
    # NOTE(review): accepted but not read by query_collection_handler.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
483
484


485
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Retrieve the chunks most relevant to ``form_data.query`` across collections.

    Uses hybrid search with the reranker when
    ``app.state.ENABLE_RAG_HYBRID_SEARCH`` is on, otherwise plain vector
    search. ``k`` and ``r`` fall back to the global settings when omitted.
    An explicit ``r=0.0`` is honoured rather than mistaken for "not provided".

    Raises:
        HTTPException: 400 wrapping any retrieval error.
    """
    k = form_data.k if form_data.k else app.state.TOP_K
    r = form_data.r if form_data.r is not None else app.state.RELEVANCE_THRESHOLD

    # NOTE(review): form_data.hybrid is not consulted; only the global flag
    # chooses the search mode.
    try:
        if app.state.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=k,
                reranking_function=app.state.sentence_transformer_rf,
                r=r,
            )
        return query_collection(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
514
515


Timothy J. Baek's avatar
Timothy J. Baek committed
516
517
518
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video via YoutubeLoader and store it in the vector DB.

    The target collection is overwritten. When no collection name is given
    (empty string or explicit null), one is derived from a SHA-256 of the
    URL, truncated to 63 characters (presumably the vector store's name
    length limit).

    Raises:
        HTTPException: 400 wrapping any loading or storage error.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # collection_name is Optional[str]: treat null like "" instead of
        # passing None on as the collection name.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        # NOTE(review): the return value is ignored, so a failure that
        # store_data_in_vector_db reports (rather than raises) still yields
        # status True.
        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


545
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page (via get_web_loader) and store it in the vector DB.

    The target collection is overwritten. When no collection name is given
    (empty string or explicit null), one is derived from a SHA-256 of the
    URL, truncated to 63 characters (presumably the vector store's name
    length limit).

    Raises:
        HTTPException: 400 wrapping any URL-validation, fetch or storage error.
    """
    # Example input: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        loader = get_web_loader(
            form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # collection_name is Optional[str]: treat null like "" instead of
        # passing None on as the collection name.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        # NOTE(review): the return value is ignored, so a failure that
        # store_data_in_vector_db reports (rather than raises) still yields
        # status True.
        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


572
def get_web_loader(url: str, verify_ssl: bool = True):
573
574
575
    # Check if the URL is valid
    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
576
    if not ENABLE_RAG_LOCAL_WEB_FETCH:
577
578
579
580
581
582
583
584
585
586
587
588
        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
        parsed_url = urllib.parse.urlparse(url)
        # Get IPv4 and IPv6 addresses
        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
        # Check if any of the resolved addresses are private
        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
        for ip in ipv4_addresses:
            if validators.ipv4(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
        for ip in ipv6_addresses:
            if validators.ipv6(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
589
    return WebBaseLoader(url, verify_ssl=verify_ssl)
590
591
592
593
594
595
596
597
598
599
600
601
602


def resolve_hostname(hostname):
    """Resolve *hostname* and split the resulting addresses by family.

    Returns:
        ``(ipv4_addresses, ipv6_addresses)``: two lists of address strings in
        ``socket.getaddrinfo`` order (entries may repeat, e.g. once per
        socket type).

    Raises:
        socket.gaierror: if the name cannot be resolved.
    """
    ipv4, ipv6 = [], []
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            ipv4.append(sockaddr[0])
        elif family == socket.AF_INET6:
            ipv6.append(sockaddr[0])
    return ipv4, ipv6


603
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in *collection_name*.

    Chunking uses ``app.state.CHUNK_SIZE`` / ``app.state.CHUNK_OVERLAP``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` when splitting yields no
            chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    # Return the bool itself: a (bool, None) tuple is always truthy, which
    # hid storage failures from callers that test `if result:`.
    return store_docs_in_vector_db(docs, collection_name, overwrite)
618
619
620


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw *text* string, tag every chunk with *metadata*, and store it.

    Returns:
        The result of ``store_docs_in_vector_db``.
    """
    splitter_options = dict(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = RecursiveCharacterTextSplitter(**splitter_options).create_documents(
        [text], metadatas=[metadata]
    )
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
632
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed *docs* and add them to the Chroma collection *collection_name*.

    Args:
        docs: Documents exposing ``page_content`` and ``metadata``.
        collection_name: Target collection name.
        overwrite: If True, an existing collection with this name is deleted
            first. Otherwise creating an already-existing collection fails
            with ``UniqueConstraintError``, which is treated as "already
            stored".

    Returns:
        True on success or on ``UniqueConstraintError``; False on any other
        error (errors are logged, not raised). If a failure happens after the
        collection was created, that collection is dropped again.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    created = False
    try:
        if overwrite:
            for collection in CHROMA_CLIENT.list_collections():
                if collection_name == collection.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)
        created = True

        embedding_func = get_embedding_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        # Newlines are flattened only for embedding; the stored documents
        # keep their original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        for batch in create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        ):
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)

        if e.__class__.__name__ == "UniqueConstraintError":
            return True

        if created:
            # An empty or half-filled collection left behind would make the
            # next non-overwrite store of the same name hit
            # UniqueConstraintError above and wrongly report success.
            try:
                CHROMA_CLIENT.delete_collection(name=collection_name)
            except Exception as cleanup_error:
                log.exception(cleanup_error)

        return False


676
677
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a LangChain document loader for an uploaded file.

    Checks are tried in a fixed order, mixing file extension and MIME type.
    Anything unrecognised falls back to ``TextLoader`` with encoding
    auto-detection.

    Returns:
        ``(loader, known_type)``. ``known_type`` is False only when the final
        catch-all text fallback was used.
    """
    file_ext = filename.split(".")[-1].lower()

    # Source/config extensions that are read as plain text.
    source_code_exts = [
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    ]

    if file_ext == "pdf":
        return PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES), True
    if file_ext == "csv":
        return CSVLoader(file_path), True
    if file_ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if file_ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if file_ext in ["htm", "html"]:
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if file_ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True

    is_word = (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or file_ext in ["doc", "docx"]
    )
    if is_word:
        return Docx2txtLoader(file_path), True

    is_spreadsheet = file_content_type in [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ] or file_ext in ["xls", "xlsx"]
    if is_spreadsheet:
        return UnstructuredExcelLoader(file_path), True

    # Both remaining cases load as text; they differ only in whether the
    # type counts as "known".
    is_textual = file_ext in source_code_exts or (
        file_content_type and "text/" in file_content_type
    )
    return TextLoader(file_path, autodetect_encoding=True), bool(is_textual)


@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded document, parse it and index it into the vector store.

    The upload is written to ``UPLOAD_DIR`` under its base name, parsed with
    the loader selected by ``get_loader`` and stored with
    ``store_data_in_vector_db``.

    Args:
        collection_name: Target collection. When omitted or empty, the first
            63 characters of ``calculate_sha256`` over the file are used.
        file: The uploaded document.
        user: Authenticated user (dependency); not otherwise used here.

    Returns:
        A dict with ``status``, ``collection_name``, ``filename`` and
        ``known_type`` (``False`` when ``get_loader`` fell back to its
        generic text loader).

    Raises:
        HTTPException: 400 if saving or parsing the file fails (with a
            dedicated message when pandoc is missing); 500 if storing the
            parsed data in the vector store fails or reports failure.
    """
    log.info(f"file.content_type: {file.content_type}")

    # Phase 1: persist and parse the upload. Failures here are client errors.
    try:
        # basename() strips directory components so a crafted filename such
        # as "../../x" cannot escape UPLOAD_DIR.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        # Stream the upload to disk instead of buffering it all in memory.
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)

        if not collection_name:
            with open(file_path, "rb") as f:
                # 63-char cap matches the other endpoints in this module
                # (presumably the vector store's collection-name limit).
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            ) from e
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e

    # Phase 2: index the parsed documents. Kept outside the try above so the
    # 500 raised here is not swallowed and re-wrapped as a 400.
    try:
        result = store_data_in_vector_db(data, collection_name)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e

    if not result:
        # A falsy result means the store failed; report it instead of
        # returning an empty 200 response (same as store_text).
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "filename": filename,
        "known_type": known_type,
    }


class TextRAGForm(BaseModel):
    # Request body for POST /text: a named text snippet to index.

    # Stored as the "name" metadata entry of the indexed text.
    name: str
    # The raw text that gets embedded.
    content: str
    # Target collection; when omitted the endpoint derives one from a hash
    # of `content`.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text snippet into the vector store.

    Args:
        form_data: Snippet name, content and optional target collection.
        user: Authenticated user (dependency); its id is stored as the
            ``created_by`` metadata entry.

    Returns:
        ``{"status": True, "collection_name": ...}`` on success.

    Raises:
        HTTPException: 500 if ``store_text_in_vector_db`` reports failure.
    """
    collection_name = form_data.collection_name
    if not collection_name:
        # Cap at 63 characters like the file-based endpoints do. This assumes
        # calculate_sha256_string returns a 64-char hex digest, one more than
        # the vector store accepts for a collection name. TODO confirm.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}

    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=ERROR_MESSAGES.DEFAULT(),
    )


@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file found recursively under ``DOCS_DIR`` (admin only).

    For each file:
    - hash its contents to derive a collection name;
    - parse it with ``get_loader`` and store it via ``store_data_in_vector_db``;
    - if no document with the sanitized filename exists yet, register it with
      ``Documents.insert_new_doc``.

    Folder names returned by ``extract_folders_after_data_docs`` become the
    document's tags.

    Args:
        user: Admin user (dependency); recorded as the creator of new docs.

    Returns:
        Always ``True``. Failures for individual files are logged and skipped.
    """
    # rglob() already recurses into every level, so "*" is enough; a nested
    # "**" pattern only makes pathlib re-walk subdirectories for the same set.
    for path in Path(DOCS_DIR).rglob("*"):
        try:
            # Skip directories and hidden files such as ".DS_Store".
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            # guess_type returns (type, encoding); only the MIME type is used.
            file_content_type = mimetypes.guess_type(path)[0]

            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, _known_type = get_loader(
                filename, file_content_type, str(path)
            )
            data = loader.load()

            # NOTE(review): every scan re-parses and re-stores all files; the
            # existence check below only prevents duplicate Document rows.
            if not store_data_in_vector_db(data, collection_name):
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is not None:
                continue

            content = (
                json.dumps({"tags": [{"name": name} for name in tags]})
                if tags
                else "{}"
            )
            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    name=sanitized_filename,
                    title=filename,
                    collection_name=collection_name,
                    filename=filename,
                    content=content,
                ),
            )
        except Exception as e:
            # Best effort: one unreadable or unparsable file must not abort
            # the whole scan.
            log.exception(e)

    return True


@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the Chroma vector store via ``CHROMA_CLIENT.reset()`` (admin only).

    Files under ``UPLOAD_DIR`` are left untouched. Any error raised by the
    client propagates to the caller, and the function returns ``None``.
    """
    chroma = CHROMA_CLIENT
    chroma.reset()


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything in ``UPLOAD_DIR`` and reset the vector store (admin only).

    Works on a best-effort basis:
    - failures to delete individual entries are logged and skipped;
    - a missing ``UPLOAD_DIR`` is treated as empty;
    - an error from ``CHROMA_CLIENT.reset()`` is logged rather than raised.

    Returns:
        Always ``True``.
    """
    # NOTE(review): a destructive action exposed via GET can be triggered by
    # link prefetchers or cross-site requests; POST/DELETE would be safer but
    # changing the method would break existing clients.
    folder = UPLOAD_DIR
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        # Nothing to delete; still go on to reset the vector store.
        entries = []

    for name in entries:
        file_path = os.path.join(folder, name)
        try:
            # Symlinks are unlinked rather than followed, so their targets
            # survive even when they point at directories.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s", file_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True