main.py 28.5 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
    YoutubeLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
)
33
34
from langchain.text_splitter import RecursiveCharacterTextSplitter

35
36
37
38
39
import validators
import urllib.parse
import socket


40
41
from pydantic import BaseModel
from typing import Optional
42
import mimetypes
43
import uuid
44
45
import json

46
import sentence_transformers
47

48
49
50
51
52
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
53

54
from apps.rag.utils import (
55
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
56
57
58
59
60
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
61
)
Timothy J. Baek's avatar
Timothy J. Baek committed
62

63
64
65
66
67
68
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
69
from utils.utils import get_current_user, get_admin_user
70

71
from config import (
72
    SRC_LOG_LEVELS,
73
74
    UPLOAD_DIR,
    DOCS_DIR,
75
76
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
77
    RAG_EMBEDDING_ENGINE,
78
    RAG_EMBEDDING_MODEL,
79
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
80
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
81
    ENABLE_RAG_HYBRID_SEARCH,
82
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
83
    RAG_RERANKING_MODEL,
84
    PDF_EXTRACT_IMAGES,
85
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
86
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
87
88
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
89
    DEVICE_TYPE,
90
91
92
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
93
    RAG_TEMPLATE,
94
    ENABLE_RAG_LOCAL_WEB_FETCH,
95
    YOUTUBE_LOADER_LANGUAGE,
96
    AppConfig,
97
)
98

99
100
from constants import ERROR_MESSAGES

101
102
103
# Module logger; its level comes from the per-source log-level table.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Mutable runtime configuration, seeded from the values imported from `config`.
app.state.config = AppConfig()

# Retrieval defaults.
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD

app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Text-splitter parameters.
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking model selection and the prompt template.
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# Credentials handed to get_embedding_function().
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
# Kept directly on app.state (not on app.state.config); starts unset.
app.state.YOUTUBE_LOADER_TRANSLATION = None


135
136
137
138
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load (or clear) the local sentence-transformers embedding model.

    A ``SentenceTransformer`` is created only when ``embedding_model`` is
    non-empty and the configured embedding engine is the empty string. In
    every other case ``app.state.sentence_transformer_ef`` is reset to
    ``None``.

    Args:
        embedding_model: Model identifier passed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` as its second argument.
    """
    if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
            get_model_path(embedding_model, update_model),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
        )
    else:
        app.state.sentence_transformer_ef = None


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Install the cross-encoder used for reranking, or clear it.

    Args:
        reranking_model: Model identifier passed to ``get_model_path``; a
            falsy value clears ``app.state.sentence_transformer_rf``.
        update_model: Forwarded to ``get_model_path`` as its second argument.
    """
    # Guard clause: nothing to load, so drop any previously loaded reranker.
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    resolved_path = get_model_path(reranking_model, update_model)
    reranker = sentence_transformers.CrossEncoder(
        resolved_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )
    app.state.sentence_transformer_rf = reranker


# Load the configured models once at import time.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)

update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Callable used by the query handlers to embed text.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
)

origins = ["*"]

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended for this sub-application.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
194
class CollectionNameForm(BaseModel):
    """Request body that carries only a target collection name."""

    # Defaults to the literal "test" when the client omits the field.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
198
class UrlForm(CollectionNameForm):
    """Request body for the URL-based ingestion endpoints (/youtube, /web)."""

    # URL of the resource to load.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
201

Timothy J. Baek's avatar
Timothy J. Baek committed
202
203
@app.get("/")
async def get_status():
    """Return a snapshot of the current RAG configuration.

    This route declares no user dependency.
    """
    return {
        "status": True,
        "chunk_size": app.state.config.CHUNK_SIZE,
        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        "template": app.state.config.RAG_TEMPLATE,
        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
215
216
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and the OpenAI connection settings.

    Guarded by the ``get_admin_user`` dependency. The response includes the
    configured API key.
    """
    return {
        "status": True,
        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
        "openai_config": {
            "url": app.state.config.OPENAI_API_BASE_URL,
            "key": app.state.config.OPENAI_API_KEY,
        },
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
228
229
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the configured reranking model (``get_admin_user`` dependency).

    The function name keeps its original spelling so the route's generated
    identifiers do not change.
    """
    return {
        "status": True,
        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
    }
Steven Kreitzer's avatar
Steven Kreitzer committed
234
235


236
237
238
239
240
class OpenAIConfigForm(BaseModel):
    """Connection settings submitted with an embedding-config update."""

    # Stored as app.state.config.OPENAI_API_BASE_URL by the update handler.
    url: str
    # Stored as app.state.config.OPENAI_API_KEY by the update handler.
    key: str


241
class EmbeddingModelUpdateForm(BaseModel):
    """Request body for POST /embedding/update."""

    # Only applied when the engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    # New value for RAG_EMBEDDING_ENGINE.
    embedding_engine: str
    # New value for RAG_EMBEDDING_MODEL.
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
247
248
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function.

    Guarded by the ``get_admin_user`` dependency.

    Args:
        form_data: New engine, model and optional OpenAI connection settings.

    Returns:
        The resulting embedding configuration, including the OpenAI URL/key.

    Raises:
        HTTPException: 500 when anything in the update sequence fails.
    """
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
            if form_data.openai_config is not None:
                app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.config.OPENAI_API_KEY = form_data.openai_config.key

        # `True` must be inside the call: the previous
        # `update_embedding_model(model), True` built a throwaway tuple and
        # never passed update_model=True.
        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL, True)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        return {
            "status": True,
            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.config.OPENAI_API_BASE_URL,
                "key": app.state.config.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
288
289


Steven Kreitzer's avatar
Steven Kreitzer committed
290
291
class RerankingModelUpdateForm(BaseModel):
    """Request body for POST /reranking/update."""

    # New value for app.state.config.RAG_RERANKING_MODEL.
    reranking_model: str
292

Steven Kreitzer's avatar
Steven Kreitzer committed
293
294
295
296
297
298

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it.

    Guarded by the ``get_admin_user`` dependency.

    Args:
        form_data: Carries the new reranking model identifier.

    Returns:
        The resulting reranking configuration.

    Raises:
        HTTPException: 500 when the model cannot be updated.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # `True` must be inside the call: the previous
        # `update_reranking_model(model), True` built a throwaway tuple and
        # never passed update_model=True.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
318
319
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the loader/chunking configuration (``get_admin_user`` dependency)."""
    return {
        "status": True,
        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.config.CHUNK_SIZE,
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        },
        "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": {
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
            # The translation setting lives on app.state, not app.state.config.
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
    }


class ChunkParamUpdateForm(BaseModel):
    """Text-splitter parameters accepted by POST /config/update."""

    # Stored as app.state.config.CHUNK_SIZE.
    chunk_size: int
    # Stored as app.state.config.CHUNK_OVERLAP.
    chunk_overlap: int


340
341
342
343
344
class YoutubeLoaderConfig(BaseModel):
    """YouTube loader settings accepted by POST /config/update."""

    # Passed to YoutubeLoader.from_youtube_url as `language`.
    language: List[str]
    # Passed to YoutubeLoader.from_youtube_url as `translation`.
    translation: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
345
class ConfigUpdateForm(BaseModel):
    """Request body for POST /config/update.

    Every field is optional; the handler keeps the current setting for any
    field left as ``None``.
    """

    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    web_loader_ssl_verification: Optional[bool] = None
    youtube: Optional[YoutubeLoaderConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
350
351
352
353


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply a partial update to the loader/chunking configuration.

    Guarded by the ``get_admin_user`` dependency. Each setting is reassigned
    on every call: to the submitted value when its form field is present,
    otherwise to its current value.

    Returns:
        The full configuration after the update, in the same shape as
        GET /config.
    """
    app.state.config.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else app.state.config.PDF_EXTRACT_IMAGES
    )

    app.state.config.CHUNK_SIZE = (
        form_data.chunk.chunk_size
        if form_data.chunk is not None
        else app.state.config.CHUNK_SIZE
    )

    app.state.config.CHUNK_OVERLAP = (
        form_data.chunk.chunk_overlap
        if form_data.chunk is not None
        else app.state.config.CHUNK_OVERLAP
    )

    app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
        form_data.web_loader_ssl_verification
        if form_data.web_loader_ssl_verification is not None
        else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
    )

    app.state.config.YOUTUBE_LOADER_LANGUAGE = (
        form_data.youtube.language
        if form_data.youtube is not None
        else app.state.config.YOUTUBE_LOADER_LANGUAGE
    )

    # When a `youtube` section is sent, its translation replaces the current
    # one even if it is None.
    app.state.YOUTUBE_LOADER_TRANSLATION = (
        form_data.youtube.translation
        if form_data.youtube is not None
        else app.state.YOUTUBE_LOADER_TRANSLATION
    )

    return {
        "status": True,
        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.config.CHUNK_SIZE,
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        },
        "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": {
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
    }
403
404


Timothy J. Baek's avatar
Timothy J. Baek committed
405
406
407
408
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the current RAG prompt template (``get_current_user`` dependency)."""
    return {
        "status": True,
        "template": app.state.config.RAG_TEMPLATE,
    }


413
414
415
416
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the query-time settings (``get_admin_user`` dependency).

    Keys: ``template`` (prompt template), ``k`` (TOP_K), ``r``
    (RELEVANCE_THRESHOLD) and ``hybrid`` (ENABLE_RAG_HYBRID_SEARCH).
    """
    return {
        "status": True,
        "template": app.state.config.RAG_TEMPLATE,
        "k": app.state.config.TOP_K,
        "r": app.state.config.RELEVANCE_THRESHOLD,
        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
422
423


424
425
class QuerySettingsForm(BaseModel):
    """Request body for POST /query/settings/update."""

    # Stored as TOP_K.
    k: Optional[int] = None
    # Stored as RELEVANCE_THRESHOLD.
    r: Optional[float] = None
    # Stored as RAG_TEMPLATE.
    template: Optional[str] = None
    # Stored as ENABLE_RAG_HYBRID_SEARCH.
    hybrid: Optional[bool] = None
429
430
431
432
433
434


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the query-time settings.

    Guarded by the ``get_admin_user`` dependency. This is not a partial
    update: any field that is missing or falsy is reset to a fixed default
    (the imported ``RAG_TEMPLATE``, ``k=4``, ``r=0.0``, ``hybrid=False``).

    Returns:
        The settings after the update, in the same shape as
        GET /query/settings.
    """
    app.state.config.RAG_TEMPLATE = (
        form_data.template if form_data.template else RAG_TEMPLATE
    )
    app.state.config.TOP_K = form_data.k if form_data.k else 4
    app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
    app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
        form_data.hybrid if form_data.hybrid else False
    )
    return {
        "status": True,
        "template": app.state.config.RAG_TEMPLATE,
        "k": app.state.config.TOP_K,
        "r": app.state.config.RELEVANCE_THRESHOLD,
        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
    }
450
451


452
class QueryDocForm(BaseModel):
    """Request body for POST /query/doc."""

    collection_name: str
    query: str
    # Per-request override of TOP_K; a falsy value falls back to the config.
    k: Optional[int] = None
    # Per-request override of RELEVANCE_THRESHOLD (hybrid search only).
    r: Optional[float] = None
    # Accepted by the schema; the handler reads the global hybrid flag instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
458
459


460
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single collection, using hybrid search when it is enabled.

    Guarded by the ``get_current_user`` dependency.

    Args:
        form_data: Collection name, query text and optional ``k``/``r``
            overrides.

    Returns:
        Whatever ``query_doc_with_hybrid_search`` or ``query_doc`` returns.

    Raises:
        HTTPException: 400 when the underlying query raises.
    """
    try:
        # NOTE(review): the overrides use truthiness, so an explicit k=0 or
        # r=0.0 falls back to the configured value - confirm that is intended.
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.config.TOP_K,
                reranking_function=app.state.sentence_transformer_rf,
                r=(
                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
                ),
            )
        else:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.config.TOP_K,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
490
491


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
492
493
494
class QueryCollectionsForm(BaseModel):
    """Request body for POST /query/collection."""

    collection_names: List[str]
    query: str
    # Per-request override of TOP_K; a falsy value falls back to the config.
    k: Optional[int] = None
    # Per-request override of RELEVANCE_THRESHOLD (hybrid search only).
    r: Optional[float] = None
    # Accepted by the schema; the handler reads the global hybrid flag instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
498
499


500
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections, using hybrid search when it is enabled.

    Guarded by the ``get_current_user`` dependency.

    Args:
        form_data: Collection names, query text and optional ``k``/``r``
            overrides.

    Returns:
        Whatever ``query_collection_with_hybrid_search`` or
        ``query_collection`` returns.

    Raises:
        HTTPException: 400 when the underlying query raises.
    """
    try:
        # NOTE(review): the overrides use truthiness, so an explicit k=0 or
        # r=0.0 falls back to the configured value - confirm that is intended.
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.config.TOP_K,
                reranking_function=app.state.sentence_transformer_rf,
                r=(
                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
                ),
            )
        else:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.config.TOP_K,
            )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
531
532


Timothy J. Baek's avatar
Timothy J. Baek committed
533
534
535
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video via ``YoutubeLoader`` and index it.

    Guarded by the ``get_current_user`` dependency. When ``collection_name``
    is the empty string, the first 63 characters of the URL's SHA-256 string
    are used instead. Any existing collection of that name is overwritten.

    Returns:
        ``status``, the collection name used and the URL (as ``filename``).

    Raises:
        HTTPException: 400 when loading or storing raises.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


562
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and index its content.

    Guarded by the ``get_current_user`` dependency. When ``collection_name``
    is the empty string, the first 63 characters of the URL's SHA-256 string
    are used instead. Any existing collection of that name is overwritten.

    Returns:
        ``status``, the collection name used and the URL (as ``filename``).

    Raises:
        HTTPException: 400 when validation, loading or storing raises.
    """
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


590
def get_web_loader(url: str, verify_ssl: bool = True):
    """Validate ``url`` and build a ``WebBaseLoader`` for it.

    Args:
        url: Address of the page to load.
        verify_ssl: Forwarded to ``WebBaseLoader``.

    Returns:
        A ``WebBaseLoader`` for ``url``.

    Raises:
        ValueError: If the URL fails validation, or if local web fetch is
            disabled and the host resolves to a private address.
    """
    # Check if the URL is valid
    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    if not ENABLE_RAG_LOCAL_WEB_FETCH:
        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
        parsed_url = urllib.parse.urlparse(url)
        # Get IPv4 and IPv6 addresses
        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
        # Check if any of the resolved addresses are private
        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
        for ip in ipv4_addresses:
            if validators.ipv4(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
        for ip in ipv6_addresses:
            if validators.ipv6(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
    return WebBaseLoader(url, verify_ssl=verify_ssl)
608
609
610
611
612
613
614
615
616
617
618
619
620


def resolve_hostname(hostname):
    """Resolve ``hostname`` and split the results by address family.

    Returns:
        A ``(ipv4_addresses, ipv6_addresses)`` pair of lists, in the order
        ``socket.getaddrinfo`` reported them. Duplicates are kept.
    """
    v4_found = []
    v6_found = []

    # Each entry is (family, type, proto, canonname, sockaddr); the address
    # string is the first element of sockaddr.
    for entry in socket.getaddrinfo(hostname, None):
        family = entry[0]
        address = entry[4][0]
        if family == socket.AF_INET:
            v4_found.append(address)
        elif family == socket.AF_INET6:
            v6_found.append(address)

    return v4_found, v6_found


621
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: Documents passed to the splitter's ``split_documents``.
        collection_name: Target collection name.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.

    Raises:
        ValueError: If splitting produces no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if len(docs) == 0:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    # Return the bool itself. The previous `..., None` wrapped it in a tuple,
    # which is always truthy, so callers testing `if result:` could never
    # observe a storage failure.
    return store_docs_in_vector_db(docs, collection_name, overwrite)
636
637
638


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split one raw text into chunks and store them in a collection.

    Args:
        text: The text to split.
        metadata: Metadata given to the splitter for this text.
        collection_name: Target collection name.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    docs = text_splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(docs, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
650
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a new Chroma collection.

    Args:
        docs: Chunks exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, delete an existing collection of the same name
            before creating it.

    Returns:
        ``True`` on success, and also when the raised exception's class is
        named ``UniqueConstraintError``; ``False`` for any other exception.
        Exceptions are logged and never propagated.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        if overwrite:
            for collection in CHROMA_CLIENT.list_collections():
                if collection_name == collection.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        # Newlines are flattened only for embedding; the stored documents
        # keep the original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        for batch in create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        ):
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # Matched by class name, so the exception type needs no import here.
        if e.__class__.__name__ == "UniqueConstraintError":
            return True

        return False


694
695
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a document loader for an uploaded file.

    The choice is driven by the lower-cased text after the last ``.`` in
    ``filename`` and, for some formats, by ``file_content_type``.

    Args:
        filename: Name of the file; only its extension is inspected.
        file_content_type: MIME type reported for the upload (may be falsy).
        file_path: Path handed to the selected loader.

    Returns:
        A ``(loader, known_type)`` tuple. ``known_type`` is ``False`` only
        when no rule matched and the plain-text loader is used as a fallback.
    """
    file_ext = filename.split(".")[-1].lower()
    known_type = True

    # Extensions loaded as plain text.
    known_source_ext = [
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    ]

    if file_ext == "pdf":
        loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext in ["htm", "html"]:
        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif file_content_type == "application/epub+zip":
        loader = UnstructuredEPubLoader(file_path)
    elif (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or file_ext in ["doc", "docx"]
    ):
        loader = Docx2txtLoader(file_path)
    elif file_content_type in [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ] or file_ext in ["xls", "xlsx"]:
        loader = UnstructuredExcelLoader(file_path)
    elif file_ext in known_source_ext or (
        file_content_type and file_content_type.find("text/") >= 0
    ):
        loader = TextLoader(file_path, autodetect_encoding=True)
    else:
        # Unknown type: still try it as text, but report it as unrecognised.
        loader = TextLoader(file_path, autodetect_encoding=True)
        known_type = False

    return loader, known_type


781
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under UPLOAD_DIR and index it in the vector DB.

    Args:
        collection_name: Target collection. When omitted, the first 63
            characters of the SHA-256 of the saved file are used.
        file: The uploaded file; only the basename of its filename is kept.
        user: Authenticated user (dependency; not otherwise used here).

    Returns:
        A dict with ``status``, ``collection_name``, ``filename`` and
        ``known_type`` when ``store_data_in_vector_db`` reports success.
        Falls through (returns ``None``) when it reports a falsy result.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.PANDOC_NOT_INSTALLED`` when
            the failure text mentions a missing pandoc, otherwise 400 with
            ``ERROR_MESSAGES.DEFAULT(e)``.
    """
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # Keep only the final path component so the upload cannot escape
        # UPLOAD_DIR through directory components in the client filename.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            # Hash the file as stored on disk; the context manager closes the
            # handle even if hashing raises.
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)

            if result:
                return {
                    "status": True,
                    "collection_name": collection_name,
                    "filename": filename,
                    "known_type": known_type,
                }
        except Exception as e:
            # NOTE(review): this HTTPException is itself caught by the outer
            # ``except Exception`` below, so clients currently receive a 400,
            # not a 500. Left as-is to avoid changing the observable status.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                # Exception objects are not JSON-serializable; send the text.
                detail=str(e),
            )
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
836
837


838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
# Request body for POST /text: a piece of raw text to index, plus an optional
# target collection.
class TextRAGForm(BaseModel):
    name: str  # stored in the vector-DB metadata under the "name" key
    content: str  # the text passed to store_text_in_vector_db
    collection_name: Optional[str] = None  # when None, store_text derives one from a hash of content


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text snippet in the vector DB.

    Args:
        form_data: Text payload (``name``, ``content``) and optional
            ``collection_name``.
        user: Authenticated user; ``user.id`` is recorded as ``created_by``
            in the stored metadata.

    Returns:
        ``{"status": True, "collection_name": ...}`` on success.

    Raises:
        HTTPException: 500 with ``ERROR_MESSAGES.DEFAULT()`` when
            ``store_text_in_vector_db`` returns a falsy result.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncate to 63 characters, matching how the /doc and /scan
        # endpoints derive their hash-based collection names.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


869
870
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under DOCS_DIR and register new documents.

    For each regular file whose name does not start with ".", the file is
    hashed (first 63 characters of its SHA-256 become the collection name),
    loaded through ``get_loader`` and stored in the vector DB. When storing
    succeeds and no document with the sanitized filename exists yet, a new
    ``Documents`` row is inserted for ``user.id``, with tags taken from
    ``extract_folders_after_data_docs``.

    Failures on a single file are logged and do not stop the scan.

    Returns:
        Always ``True``.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            # Skip directories and dot-files.
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            content_type, _ = mimetypes.guess_type(path)

            # Context manager guarantees the handle is closed even if
            # hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, _ = get_loader(filename, content_type, str(path))
            data = loader.load()

            try:
                result = store_data_in_vector_db(data, collection_name)

                if result:
                    sanitized_filename = sanitize_filename(filename)
                    doc = Documents.get_doc_by_name(sanitized_filename)

                    if doc is None:
                        # Tags are serialized as {"tags": [{"name": ...}, ...]};
                        # an empty JSON object is stored when there are none.
                        content = (
                            json.dumps(
                                {"tags": [{"name": name} for name in tags]}
                            )
                            if tags
                            else "{}"
                        )
                        Documents.insert_new_doc(
                            user.id,
                            DocumentForm(
                                name=sanitized_filename,
                                title=filename,
                                collection_name=collection_name,
                                filename=filename,
                                content=content,
                            ),
                        )
            except Exception as e:
                # Best effort: a storage/registration failure for one file
                # must not abort the rest of the scan.
                log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
930
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the vector database via ``CHROMA_CLIENT.reset()`` (admin only).

    Uploaded files on disk are not touched by this endpoint. Any exception
    raised by the client propagates to the caller.
    """
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
933
934
935


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything in UPLOAD_DIR and reset the vector DB (admin only).

    Files and symlinks are unlinked and directories are removed recursively.
    A failure to delete one entry is logged and the loop continues; a failure
    of ``CHROMA_CLIENT.reset()`` is logged as well.

    Returns:
        Always ``True``.
    """
    folder = str(UPLOAD_DIR)
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            # Lazy logger arguments: same message, formatted only if emitted.
            log.error("Failed to delete %s. Reason: %s", file_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True