main.py 30.5 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
14
from typing import List, Union, Sequence
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
    YoutubeLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
)
33
34
from langchain.text_splitter import RecursiveCharacterTextSplitter

35
36
37
38
39
import validators
import urllib.parse
import socket


40
41
from pydantic import BaseModel
from typing import Optional
42
import mimetypes
43
import uuid
44
45
import json

46
import sentence_transformers
47

48
49
50
51
52
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
53

54
from apps.rag.utils import (
55
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
56
57
58
59
60
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
61
    search_web,
62
)
Timothy J. Baek's avatar
Timothy J. Baek committed
63

64
65
66
67
68
69
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
70
from utils.utils import get_current_user, get_admin_user
71

72
from config import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
73
    ENV,
74
    SRC_LOG_LEVELS,
75
76
    UPLOAD_DIR,
    DOCS_DIR,
77
78
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
79
    RAG_EMBEDDING_ENGINE,
80
    RAG_EMBEDDING_MODEL,
81
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
82
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
83
    ENABLE_RAG_HYBRID_SEARCH,
84
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
85
    RAG_RERANKING_MODEL,
86
    PDF_EXTRACT_IMAGES,
87
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
88
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
89
90
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
91
    DEVICE_TYPE,
92
93
94
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
95
    RAG_TEMPLATE,
96
    ENABLE_RAG_LOCAL_WEB_FETCH,
97
    YOUTUBE_LOADER_LANGUAGE,
98
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
99
    AppConfig,
100
)
101

102
103
from constants import ERROR_MESSAGES

104
105
106
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

# The RAG sub-application. Settings that the admin endpoints below can change
# at runtime live on `app.state.config` (an AppConfig) and are seeded here from
# the values imported from the `config` module.
app = FastAPI()
app.state.config = AppConfig()

# Retrieval defaults (top-k and relevance threshold).
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD

# Hybrid search toggle and SSL verification for web page loading.
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Text-splitter parameters.
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking model selection and the prompt template.
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# Connection settings used by the remote ("ollama"/"openai") embedding engines.
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

# YouTube transcript loading. The translation setting is stored directly on
# `app.state`, not on `app.state.config`.
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None


138
139
140
141
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load or clear the local SentenceTransformer embedding model.

    A local model is only loaded when ``embedding_model`` is non-empty AND the
    configured embedding engine is the empty string (the built-in
    sentence-transformers backend). In every other case
    ``app.state.sentence_transformer_ef`` is reset to ``None``.

    Args:
        embedding_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path``; presumably allows the
            model snapshot to be downloaded/refreshed — TODO confirm.
    """
    if not embedding_model or app.state.config.RAG_EMBEDDING_ENGINE != "":
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load or clear the CrossEncoder used as the hybrid-search reranker.

    A falsy ``reranking_model`` clears ``app.state.sentence_transformer_rf``;
    otherwise the model resolved by ``get_model_path`` is loaded on
    ``DEVICE_TYPE``.

    Args:
        reranking_model: Model identifier handed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path``; presumably allows the
            model snapshot to be downloaded/refreshed — TODO confirm.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the configured models at import time; the *_AUTO_UPDATE flags are
# forwarded as each loader's `update_model` argument.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Embedding callable used by the /query endpoints; it is rebuilt whenever the
# embedding configuration is updated.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
185
186
origins = ["*"]

# Fully permissive CORS: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive for an authenticated API — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)


Timothy J. Baek's avatar
Timothy J. Baek committed
197
class CollectionNameForm(BaseModel):
    # Target vector-store collection. The ingestion endpoints treat "" as
    # "derive a name from the input"; when omitted it defaults to "test".
    collection_name: Union[str, None] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
201
class UrlForm(CollectionNameForm):
    # Request body for the single-URL ingestion endpoints (/youtube, /web).
    url: str  # page or video URL to load

Timothy J. Baek's avatar
Timothy J. Baek committed
204

205
206
207
208
class SearchForm(CollectionNameForm):
    # Request body for /websearch.
    query: str  # search terms passed to search_web


Timothy J. Baek's avatar
Timothy J. Baek committed
209
210
@app.get("/")
async def get_status():
    # Status probe (no user dependency declared) reporting the chunking
    # parameters, prompt template and selected embedding/reranking models.
    config = app.state.config
    return {
        "status": True,
        "chunk_size": config.CHUNK_SIZE,
        "chunk_overlap": config.CHUNK_OVERLAP,
        "template": config.RAG_TEMPLATE,
        "embedding_engine": config.RAG_EMBEDDING_ENGINE,
        "embedding_model": config.RAG_EMBEDDING_MODEL,
        "reranking_model": config.RAG_RERANKING_MODEL,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
222
223
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    # Embedding engine/model plus the remote-engine connection settings,
    # guarded by get_admin_user.
    # NOTE(review): the API key is returned verbatim in the response.
    config = app.state.config
    openai_config = {
        "url": config.OPENAI_API_BASE_URL,
        "key": config.OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": config.RAG_EMBEDDING_ENGINE,
        "embedding_model": config.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
235
236
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    # Report the configured reranking model (guarded by get_admin_user).
    # The function name's spelling is left unchanged in case it is imported
    # elsewhere.
    current_model = app.state.config.RAG_RERANKING_MODEL
    return {"status": True, "reranking_model": current_model}
Steven Kreitzer's avatar
Steven Kreitzer committed
241
242


243
244
245
246
247
class OpenAIConfigForm(BaseModel):
    # Connection settings applied when the embedding engine is "ollama" or
    # "openai".
    url: str  # API base URL
    key: str  # API key


248
class EmbeddingModelUpdateForm(BaseModel):
    # Request body for POST /embedding/update.
    openai_config: Union[OpenAIConfigForm, None] = None  # only used by remote engines
    embedding_engine: str  # "" selects the local sentence-transformers backend
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
254
255
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    # Switch the embedding engine/model (guarded by get_admin_user), reload the
    # local model if needed and rebuild the shared embedding function.
    # Any failure is logged and surfaced as HTTP 500.
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        config = app.state.config
        config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        # Remote engines take their URL/key from the optional openai_config.
        uses_remote_engine = config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
        if uses_remote_engine and form_data.openai_config is not None:
            config.OPENAI_API_BASE_URL = form_data.openai_config.url
            config.OPENAI_API_KEY = form_data.openai_config.key

        update_embedding_model(config.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            config.RAG_EMBEDDING_ENGINE,
            config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            config.OPENAI_API_KEY,
            config.OPENAI_API_BASE_URL,
        )

        openai_settings = {
            "url": config.OPENAI_API_BASE_URL,
            "key": config.OPENAI_API_KEY,
        }
        return {
            "status": True,
            "embedding_engine": config.RAG_EMBEDDING_ENGINE,
            "embedding_model": config.RAG_EMBEDDING_MODEL,
            "openai_config": openai_settings,
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
295
296


Steven Kreitzer's avatar
Steven Kreitzer committed
297
298
class RerankingModelUpdateForm(BaseModel):
    # Request body for POST /reranking/update; a falsy value disables reranking.
    reranking_model: str
299

Steven Kreitzer's avatar
Steven Kreitzer committed
300
301
302
303
304
305

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    # Switch the reranking model (guarded by get_admin_user) and reload it.
    # Any failure is logged and surfaced as HTTP 500.
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # update_model=True is forwarded to get_model_path so a newly selected
        # model can be fetched/refreshed (presumed semantics — see
        # get_model_path). It must be passed INSIDE the call: a `, True` after
        # the closing parenthesis only builds a discarded tuple.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, update_model=True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
325
326
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    # Document-processing settings (guarded by get_admin_user): PDF image
    # extraction, chunking, web-loader SSL verification and YouTube options.
    config = app.state.config
    chunk_settings = {
        "chunk_size": config.CHUNK_SIZE,
        "chunk_overlap": config.CHUNK_OVERLAP,
    }
    youtube_settings = {
        "language": config.YOUTUBE_LOADER_LANGUAGE,
        "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
    }
    return {
        "status": True,
        "pdf_extract_images": config.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
        "web_loader_ssl_verification": config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": youtube_settings,
    }


class ChunkParamUpdateForm(BaseModel):
    # Text-splitter parameters, updated together.
    chunk_size: int
    chunk_overlap: int


347
348
349
350
351
class YoutubeLoaderConfig(BaseModel):
    # Options forwarded to YoutubeLoader.from_youtube_url.
    language: List[str]
    translation: Union[str, None] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
352
class ConfigUpdateForm(BaseModel):
    # Partial update for POST /config/update; a None field keeps its value.
    pdf_extract_images: Union[bool, None] = None
    chunk: Union[ChunkParamUpdateForm, None] = None
    web_loader_ssl_verification: Union[bool, None] = None
    youtube: Union[YoutubeLoaderConfig, None] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
357
358
359
360


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    # Partially update document-processing settings (guarded by get_admin_user).
    # Every setting is re-assigned on each call; a field that is None in the
    # request is re-assigned its current value.
    config = app.state.config
    chunk = form_data.chunk
    youtube = form_data.youtube

    config.PDF_EXTRACT_IMAGES = (
        config.PDF_EXTRACT_IMAGES
        if form_data.pdf_extract_images is None
        else form_data.pdf_extract_images
    )
    config.CHUNK_SIZE = config.CHUNK_SIZE if chunk is None else chunk.chunk_size
    config.CHUNK_OVERLAP = (
        config.CHUNK_OVERLAP if chunk is None else chunk.chunk_overlap
    )
    config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
        config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
        if form_data.web_loader_ssl_verification is None
        else form_data.web_loader_ssl_verification
    )
    config.YOUTUBE_LOADER_LANGUAGE = (
        config.YOUTUBE_LOADER_LANGUAGE if youtube is None else youtube.language
    )
    app.state.YOUTUBE_LOADER_TRANSLATION = (
        app.state.YOUTUBE_LOADER_TRANSLATION
        if youtube is None
        else youtube.translation
    )

    chunk_settings = {
        "chunk_size": config.CHUNK_SIZE,
        "chunk_overlap": config.CHUNK_OVERLAP,
    }
    youtube_settings = {
        "language": config.YOUTUBE_LOADER_LANGUAGE,
        "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
    }
    return {
        "status": True,
        "pdf_extract_images": config.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
        "web_loader_ssl_verification": config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": youtube_settings,
    }
410
411


Timothy J. Baek's avatar
Timothy J. Baek committed
412
413
414
415
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    # Current RAG prompt template; available to any user accepted by
    # get_current_user.
    template = app.state.config.RAG_TEMPLATE
    return {"status": True, "template": template}


420
421
422
423
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    # Retrieval settings (guarded by get_admin_user): template, top-k ("k"),
    # relevance threshold ("r") and the hybrid-search flag.
    config = app.state.config
    return {
        "status": True,
        "template": config.RAG_TEMPLATE,
        "k": config.TOP_K,
        "r": config.RELEVANCE_THRESHOLD,
        "hybrid": config.ENABLE_RAG_HYBRID_SEARCH,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
429
430


431
432
class QuerySettingsForm(BaseModel):
    # Request body for POST /query/settings/update.
    k: Union[int, None] = None  # top-k
    r: Union[float, None] = None  # relevance threshold
    template: Union[str, None] = None
    hybrid: Union[bool, None] = None
436
437
438
439
440
441


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    # Replace the retrieval settings (guarded by get_admin_user). A missing or
    # falsy field resets to a fixed default rather than keeping the current
    # value: template -> RAG_TEMPLATE, k -> 4, r -> 0.0, hybrid -> False.
    config = app.state.config
    config.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    config.TOP_K = form_data.k or 4
    config.RELEVANCE_THRESHOLD = form_data.r or 0.0
    config.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return {
        "status": True,
        "template": config.RAG_TEMPLATE,
        "k": config.TOP_K,
        "r": config.RELEVANCE_THRESHOLD,
        "hybrid": config.ENABLE_RAG_HYBRID_SEARCH,
    }
457
458


459
class QueryDocForm(BaseModel):
    # Request body for POST /query/doc (a single collection).
    collection_name: str
    query: str
    k: Union[int, None] = None  # falls back to the configured TOP_K when falsy
    r: Union[float, None] = None  # falls back to RELEVANCE_THRESHOLD when falsy
    hybrid: Union[bool, None] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
465
466


467
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    # Retrieve chunks from one collection for `form_data.query`.
    # A falsy k / r in the request falls back to the configured TOP_K /
    # RELEVANCE_THRESHOLD. Any failure is logged and returned as HTTP 400.
    # NOTE(review): `form_data.hybrid` is accepted but not consulted; the global
    # ENABLE_RAG_HYBRID_SEARCH flag chooses the search mode.
    try:
        use_hybrid = app.state.config.ENABLE_RAG_HYBRID_SEARCH
        top_k = form_data.k or app.state.config.TOP_K

        if not use_hybrid:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        return query_doc_with_hybrid_search(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or app.state.config.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
497
498


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
499
500
501
class QueryCollectionsForm(BaseModel):
    # Request body for POST /query/collection (several collections at once).
    collection_names: List[str]
    query: str
    k: Union[int, None] = None  # falls back to the configured TOP_K when falsy
    r: Union[float, None] = None  # falls back to RELEVANCE_THRESHOLD when falsy
    hybrid: Union[bool, None] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
505
506


507
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    # Retrieve chunks across several collections for `form_data.query`.
    # A falsy k / r in the request falls back to the configured TOP_K /
    # RELEVANCE_THRESHOLD. Any failure is logged and returned as HTTP 400.
    # NOTE(review): `form_data.hybrid` is accepted but not consulted; the global
    # ENABLE_RAG_HYBRID_SEARCH flag chooses the search mode.
    try:
        use_hybrid = app.state.config.ENABLE_RAG_HYBRID_SEARCH
        top_k = form_data.k or app.state.config.TOP_K

        if not use_hybrid:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        return query_collection_with_hybrid_search(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or app.state.config.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
538
539


Timothy J. Baek's avatar
Timothy J. Baek committed
540
541
542
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    # Load a YouTube video's transcript (with video info) and index it,
    # replacing any existing collection of the same name. An empty
    # collection_name is replaced by the first 63 hex chars of the URL's
    # SHA-256 (presumably to fit the vector store's name-length limit).
    # Any failure is logged and returned as HTTP 400.
    try:
        documents = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        ).load()

        requested_name = form_data.collection_name
        collection_name = (
            calculate_sha256_string(form_data.url)[:63]
            if requested_name == ""
            else requested_name
        )

        store_data_in_vector_db(documents, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


569
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    # Load a single web page and index it, replacing any existing collection
    # of the same name. An empty collection_name is replaced by the first 63
    # hex chars of the URL's SHA-256. Any failure is logged and returned as
    # HTTP 400.
    # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        pages = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        ).load()

        requested_name = form_data.collection_name
        collection_name = (
            calculate_sha256_string(form_data.url)[:63]
            if requested_name == ""
            else requested_name
        )

        store_data_in_vector_db(pages, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

596

597
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a ``WebBaseLoader`` for one URL or a sequence of URLs.

    The URL(s) are checked with ``validate_url`` first. The loader is
    configured to continue past per-page failures and is throttled with
    ``RAG_WEB_SEARCH_CONCURRENT_REQUESTS`` as ``requests_per_second``.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` when validation fails
            (``validate_url`` may also raise it directly).
    """
    if validate_url(url):
        return WebBaseLoader(
            url,
            verify_ssl=verify_ssl,
            requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            continue_on_failure=True,
        )
    raise ValueError(ERROR_MESSAGES.INVALID_URL)
607
608


609
610
611
612
def validate_url(url: Union[str, Sequence[str]]):
    """Check that ``url`` (or every URL in a sequence) is acceptable to fetch.

    Returns:
        True for a valid string URL; for a sequence, ``all`` of the per-URL
        results; False when ``url`` is neither a ``str`` nor a ``Sequence``.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` if a URL is malformed, or —
            when ENABLE_RAG_LOCAL_WEB_FETCH is disabled — if its hostname
            resolves to a private/local address.
    """
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled: reject URLs whose host resolves to a
            # private/local address. This is still vulnerable to DNS rebinding,
            # because WebBaseLoader performs its own resolution later.
            parsed_url = urllib.parse.urlparse(url)
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            for ip in ipv4_addresses:
                if validators.ipv4(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
            # validators.ipv6 accepts no `private` argument; calling it with one
            # yields a falsy ValidationError, so it never flagged anything.
            # Classify IPv6 addresses with the standard library instead.
            for ip in ipv6_addresses:
                if _is_non_public_ipv6(ip):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False


def _is_non_public_ipv6(ip: str) -> bool:
    """Return True if ``ip`` is a private, loopback or link-local IPv6 address."""
    import ipaddress

    # getaddrinfo may append a zone index (e.g. "fe80::1%eth0"); drop it so the
    # address parses on every supported Python version.
    address = ipaddress.ip_address(ip.split("%", 1)[0])
    return address.is_private or address.is_loopback or address.is_link_local


633
634
635
636
637
638
639
640
641
642
643
def resolve_hostname(hostname):
    """Resolve ``hostname`` and split the resulting addresses by family.

    Returns:
        ``(ipv4_addresses, ipv6_addresses)`` — lists of address strings in the
        order ``socket.getaddrinfo`` reports them; duplicates are kept.
    """
    ipv4_addresses = []
    ipv6_addresses = []
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            ipv4_addresses.append(sockaddr[0])
        elif family == socket.AF_INET6:
            ipv6_addresses.append(sockaddr[0])
    return ipv4_addresses, ipv6_addresses


644
645
646
@app.post("/websearch")
def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
    # Run a web search for `form_data.query`, load every result page and index
    # them into one collection (replacing any existing one). An empty
    # collection_name is replaced by the first 63 hex chars of the query's
    # SHA-256.
    #
    # Errors: a failing search is reported as 400 WEB_SEARCH_ERROR; any later
    # failure is logged and reported as 400 with ERROR_MESSAGES.DEFAULT(e).
    try:
        try:
            web_results = search_web(form_data.query)
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
            )

        urls = [result.link for result in web_results]
        # Honour the admin's SSL-verification setting, as store_web does.
        loader = get_web_loader(
            urls,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        # NOTE(review): aload() is used synchronously here; this assumes the
        # installed WebBaseLoader.aload returns documents, not a coroutine.
        data = loader.aload()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except HTTPException:
        # Already a client-facing error (e.g. WEB_SEARCH_ERROR); don't re-wrap.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


677
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in ``collection_name``.

    Chunking uses ``app.state.config.CHUNK_SIZE``/``CHUNK_OVERLAP`` and records
    each chunk's start index.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` if splitting yields no chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.split_documents(data)

    if not chunks:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {chunks}")
    # NOTE(review): despite the `-> bool` annotation this returns a
    # (bool, None) tuple, which is always truthy, so a caller testing the
    # result cannot detect a failed store. Check callers before changing it.
    return store_docs_in_vector_db(chunks, collection_name, overwrite), None
692
693
694


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a single text into chunks tagged with ``metadata`` and store them.

    Chunking uses ``app.state.config.CHUNK_SIZE``/``CHUNK_OVERLAP`` and records
    each chunk's start index. Returns whatever ``store_docs_in_vector_db``
    returns.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
706
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed ``docs`` and add them to a newly created Chroma collection.

    Args:
        docs: Chunk objects exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: If true, a collection with the same name is deleted first.

    Returns:
        True on success, or when the failure is a ``UniqueConstraintError``
        (presumably raised when the collection already exists — TODO confirm);
        False for any other exception, which is logged.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if collection_name == existing.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        # Newlines are flattened to spaces for embedding only; the stored
        # documents keep the original text.
        embeddings = embedding_func([text.replace("\n", " ") for text in texts])

        # create_batches splits the payload into chunks the Chroma client
        # accepts; each batch is unpacked positionally into collection.add.
        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        return e.__class__.__name__ == "UniqueConstraintError"


750
751
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a LangChain document loader for a file.

    Most formats are matched on the extension, which is the last dot-separated
    part of ``filename``, lowercased. A name without a dot is used whole, which
    is how e.g. ``Dockerfile`` matches ``"dockerfile"``. The MIME type is also
    consulted for EPUB, Word, Excel and generic ``text/*`` content.

    Args:
        filename: Name of the file as uploaded or found on disk.
        file_content_type: MIME type from the client or from ``mimetypes``;
            may be ``None`` (guarded before the ``text/`` check).
        file_path: Path handed to the selected loader.

    Returns:
        Tuple ``(loader, known_type)``. ``known_type`` is False only when no
        rule matched and the file falls back to ``TextLoader``.
    """
    file_ext = filename.split(".")[-1].lower()
    known_type = True

    # Source/config/log extensions that are read as plain text.
    known_source_ext = {
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    }

    if file_ext == "pdf":
        loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext in ["htm", "html"]:
        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif file_content_type == "application/epub+zip" or file_ext == "epub":
        # Match the extension too: the MIME type can be missing or generic
        # (mimetypes.guess_type may return None when the platform's MIME
        # database has no EPUB entry). Without this, EPUB archives would
        # fall through to TextLoader.
        loader = UnstructuredEPubLoader(file_path)
    elif (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or file_ext in ["doc", "docx"]
    ):
        # NOTE(review): legacy binary .doc is routed here as well; confirm
        # Docx2txtLoader can actually read it.
        loader = Docx2txtLoader(file_path)
    elif file_content_type in [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ] or file_ext in ["xls", "xlsx"]:
        loader = UnstructuredExcelLoader(file_path)
    elif file_ext in known_source_ext or (
        file_content_type and "text/" in file_content_type
    ):
        loader = TextLoader(file_path, autodetect_encoding=True)
    else:
        # Unknown type: still attempt a plain-text load, but report it.
        loader = TextLoader(file_path, autodetect_encoding=True)
        known_type = False

    return loader, known_type


837
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    # Save an uploaded file under UPLOAD_DIR, parse it with the loader picked
    # by get_loader, and index the result via store_data_in_vector_db.
    #
    # collection_name: target collection. When omitted, the first 63
    #     characters of calculate_sha256 over the saved file are used.
    # Returns {"status", "collection_name", "filename", "known_type"} on success.
    # Errors:
    #   - 500 when store_data_in_vector_db reports failure (falsy result).
    #   - 400 for any other exception, with a dedicated message when the
    #     error text mentions a missing pandoc.

    log.info(f"file.content_type: {file.content_type}")
    try:
        # basename() drops any directory components in the client-supplied name.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            # Context manager closes the handle even if hashing raises.
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        result = store_data_in_vector_db(data, collection_name)
        if not result:
            # A falsy result means indexing failed. Report it instead of
            # returning null with HTTP 200; this mirrors store_text.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except HTTPException:
        # Already shaped for the client; do not re-wrap it as a 400 below.
        raise
    except Exception as e:
        log.exception(e)

        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
892
893


894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
# Request body for the POST /text endpoint (store_text): raw text to index.
class TextRAGForm(BaseModel):
    # Stored in the chunk metadata under the "name" key by store_text.
    name: str
    # The text that store_text indexes into the vector DB.
    content: str
    # Target collection; when None, store_text derives one from a hash of content.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    # Index raw text from the request body into the vector DB.
    # The text is tagged with {"name": form_data.name, "created_by": user.id}.
    # Returns {"status": True, "collection_name": ...} on success; raises 500
    # when store_text_in_vector_db reports failure.

    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncate to 63 characters, as store_doc and scan_docs_dir do. Chroma
        # rejects collection names longer than 63 characters, and the hash is
        # presumably a 64-character SHA-256 hex digest.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


925
926
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    # Walk DOCS_DIR recursively and index every file whose name does not start
    # with "." into the vector DB. Behavior per file:
    #   - The collection name is the first 63 characters of calculate_sha256
    #     over the file.
    #   - After a successful store, a Documents row is created unless one
    #     already exists under the sanitized filename.
    #   - Folder names from extract_folders_after_data_docs are saved as tags
    #     in the row's JSON content.
    # Per-file failures are logged and skipped; the handler always returns True.
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            guessed_type, _encoding = mimetypes.guess_type(path)

            # Context manager closes the handle even if hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, _known_type = get_loader(filename, guessed_type, str(path))
            data = loader.load()

            if not store_data_in_vector_db(data, collection_name):
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is not None:
                continue

            content = (
                json.dumps({"tags": [{"name": name} for name in tags]})
                if len(tags)
                else "{}"
            )
            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    **{
                        "name": sanitized_filename,
                        "title": filename,
                        "collection_name": collection_name,
                        "filename": filename,
                        "content": content,
                    }
                ),
            )
        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
986
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    # Guarded by the get_admin_user dependency. Calls reset() on the Chroma
    # client; any exception it raises propagates, and no value is returned.
    client = CHROMA_CLIENT
    client.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
989
990
991


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    # Guarded by the get_admin_user dependency. Steps:
    #   1. Empty UPLOAD_DIR: files and symlinks are unlinked, and
    #      subdirectories are removed recursively.
    #   2. Reset the Chroma client.
    # Per-entry deletion failures and reset failures are logged, not raised.
    upload_root = f"{UPLOAD_DIR}"
    for entry in os.listdir(upload_root):
        entry_path = os.path.join(upload_root, entry)
        try:
            removable_as_file = os.path.isfile(entry_path) or os.path.islink(
                entry_path
            )
            if removable_as_file:
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s", entry_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020


if ENV == "dev":
    # Development-only debug routes: return the embedding of a fixed string or
    # of a path parameter.
    # NOTE(review): app.state.EMBEDDING_FUNCTION is not assigned anywhere in
    # the visible code (store_docs_in_vector_db builds its own embedding
    # function). Confirm it is set elsewhere, otherwise these routes raise
    # AttributeError.

    @app.get("/ef")
    async def get_embeddings():
        vector = app.state.EMBEDDING_FUNCTION("hello world")
        return {"result": vector}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        vector = app.state.EMBEDDING_FUNCTION(text)
        return {"result": vector}