main.py 30.8 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
14
from typing import List, Union, Sequence
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
    YoutubeLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
33
)
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

36
37
38
39
40
import validators
import urllib.parse
import socket


41
42
from pydantic import BaseModel
from typing import Optional
43
import mimetypes
44
import uuid
45
46
import json

47
import sentence_transformers
48

Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
49
from apps.webui.models.documents import (
50
51
52
53
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
54

55
from apps.rag.utils import (
56
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
57
58
59
60
61
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
62
    search_web,
63
)
Timothy J. Baek's avatar
Timothy J. Baek committed
64

65
66
67
68
69
70
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
71
from utils.utils import get_current_user, get_admin_user
72

73
from config import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
74
    ENV,
75
    SRC_LOG_LEVELS,
76
77
    UPLOAD_DIR,
    DOCS_DIR,
78
79
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
80
    RAG_EMBEDDING_ENGINE,
81
    RAG_EMBEDDING_MODEL,
82
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
83
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
84
    ENABLE_RAG_HYBRID_SEARCH,
85
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
86
    RAG_RERANKING_MODEL,
87
    PDF_EXTRACT_IMAGES,
88
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
89
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
90
91
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
92
    DEVICE_TYPE,
93
94
95
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
96
    RAG_TEMPLATE,
97
    ENABLE_RAG_LOCAL_WEB_FETCH,
98
    YOUTUBE_LOADER_LANGUAGE,
99
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
100
    AppConfig,
101
)
102

103
104
from constants import ERROR_MESSAGES

105
106
107
# Module logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

# FastAPI application for the RAG endpoints. Runtime-adjustable settings are
# kept on ``app.state.config`` and seeded from the imported config constants.
app = FastAPI()

app.state.config = AppConfig()

# Retrieval defaults.
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD

app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Text-splitting parameters used when storing documents.
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking model selection and the prompt template.
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# Credentials handed to get_embedding_function.
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

# YouTube loader options. The translation setting is stored directly on
# ``app.state`` rather than on ``app.state.config``.
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None


139
140
141
142
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load the local SentenceTransformer embedding model, or clear it.

    A model is loaded only when ``embedding_model`` is non-empty and the
    configured ``RAG_EMBEDDING_ENGINE`` is the empty string. In every other
    case ``app.state.sentence_transformer_ef`` is reset to ``None``.

    Args:
        embedding_model: Model identifier passed to ``get_model_path``.
        update_model: Passed to ``get_model_path`` as its second argument.
    """
    if not embedding_model or app.state.config.RAG_EMBEDDING_ENGINE != "":
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load the CrossEncoder reranking model, or clear it.

    When ``reranking_model`` is empty, ``app.state.sentence_transformer_rf``
    is set to ``None``; otherwise it is replaced by a freshly constructed
    ``sentence_transformers.CrossEncoder``.

    Args:
        reranking_model: Model identifier passed to ``get_model_path``.
        update_model: Passed to ``get_model_path`` as its second argument.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the configured embedding and reranking models once at import time.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Embedding callable shared by the query handlers below.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
186
187
# Every origin, method and header is allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
198
class CollectionNameForm(BaseModel):
    """Base request body carrying the target collection name."""

    # Defaults to "test" when the field is omitted from the request.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
202
class UrlForm(CollectionNameForm):
    """Request body for the URL-based ingestion endpoints (/youtube, /web)."""

    # URL handed to the loader; also echoed back as "filename".
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
205

206
207
208
209
class SearchForm(CollectionNameForm):
    """Request body for the /web/search endpoint."""

    # Search query passed to search_web by store_web_search.
    query: str


Timothy J. Baek's avatar
Timothy J. Baek committed
210
211
@app.get("/")
async def get_status():
    """Return the current chunking, template and model settings."""
    cfg = app.state.config
    return {
        "status": True,
        "chunk_size": cfg.CHUNK_SIZE,
        "chunk_overlap": cfg.CHUNK_OVERLAP,
        "template": cfg.RAG_TEMPLATE,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "reranking_model": cfg.RAG_RERANKING_MODEL,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
223
224
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and the stored OpenAI settings.

    Access is guarded by the ``get_admin_user`` dependency.
    """
    cfg = app.state.config
    openai_config = {
        "url": cfg.OPENAI_API_BASE_URL,
        "key": cfg.OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
236
237
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the configured reranking model name.

    Access is guarded by the ``get_admin_user`` dependency. The function name
    is misspelled but kept as-is because it is part of the module interface.
    """
    reranking_model = app.state.config.RAG_RERANKING_MODEL
    return {
        "status": True,
        "reranking_model": reranking_model,
    }
Steven Kreitzer's avatar
Steven Kreitzer committed
242
243


244
245
246
247
248
class OpenAIConfigForm(BaseModel):
    """OpenAI connection settings supplied with an embedding update."""

    # Stored as app.state.config.OPENAI_API_BASE_URL by update_embedding_config.
    url: str
    # Stored as app.state.config.OPENAI_API_KEY by update_embedding_config.
    key: str


249
class EmbeddingModelUpdateForm(BaseModel):
    """Request body for POST /embedding/update."""

    # Only applied when the engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    # New value for RAG_EMBEDDING_ENGINE.
    embedding_engine: str
    # New value for RAG_EMBEDDING_MODEL.
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
255
256
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function.

    The OpenAI URL/key are only updated when the new engine is "ollama" or
    "openai" and ``form_data.openai_config`` was supplied.

    Returns:
        The resulting embedding configuration.

    Raises:
        HTTPException: 500 if any step of the update fails.
    """
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if (
            app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
            and form_data.openai_config is not None
        ):
            app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
            app.state.config.OPENAI_API_KEY = form_data.openai_config.key

        # Reload (or clear) the local model before rebuilding the callable
        # that depends on it.
        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        return {
            "status": True,
            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.config.OPENAI_API_BASE_URL,
                "key": app.state.config.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
296
297


Steven Kreitzer's avatar
Steven Kreitzer committed
298
299
class RerankingModelUpdateForm(BaseModel):
    """Request body for POST /reranking/update."""

    # New value for RAG_RERANKING_MODEL; an empty string clears the reranker
    # (update_reranking_model sets it to None for a falsy name).
    reranking_model: str
300

Steven Kreitzer's avatar
Steven Kreitzer committed
301
302
303
304
305
306

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it.

    Returns:
        The resulting reranking configuration.

    Raises:
        HTTPException: 500 if loading the new model fails.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # ``True`` is the update_model argument. It used to sit outside the
        # call's parentheses, which built a discarded tuple and left
        # update_model at its default of False.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
326
327
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF, chunking, web-loader and YouTube loader settings.

    Access is guarded by the ``get_admin_user`` dependency.
    """
    cfg = app.state.config
    chunk_settings = {
        "chunk_size": cfg.CHUNK_SIZE,
        "chunk_overlap": cfg.CHUNK_OVERLAP,
    }
    youtube_settings = {
        "language": cfg.YOUTUBE_LOADER_LANGUAGE,
        # Stored on app.state itself, not on app.state.config.
        "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
    }
    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
        "web_loader_ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": youtube_settings,
    }


class ChunkParamUpdateForm(BaseModel):
    """Chunking parameters accepted by POST /config/update."""

    # New value for app.state.config.CHUNK_SIZE.
    chunk_size: int
    # New value for app.state.config.CHUNK_OVERLAP.
    chunk_overlap: int


348
349
350
351
352
class YoutubeLoaderConfig(BaseModel):
    """YouTube loader options accepted by POST /config/update."""

    # New value for app.state.config.YOUTUBE_LOADER_LANGUAGE.
    language: List[str]
    # New value for app.state.YOUTUBE_LOADER_TRANSLATION (None when omitted).
    translation: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
353
class ConfigUpdateForm(BaseModel):
    """Request body for POST /config/update.

    Every field is optional; a field left as ``None`` keeps the current
    setting.
    """

    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    web_loader_ssl_verification: Optional[bool] = None
    youtube: Optional[YoutubeLoaderConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
358
359
360
361


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply a partial update to the RAG configuration.

    Each setting is replaced only when its form field is not ``None``;
    otherwise the current value is written back unchanged. When ``youtube`` is
    supplied, both its ``language`` and ``translation`` are applied, so an
    omitted ``translation`` resets the stored value to ``None``.

    Returns:
        The resulting configuration, in the same shape as GET /config.
    """
    cfg = app.state.config

    cfg.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else cfg.PDF_EXTRACT_IMAGES
    )

    cfg.CHUNK_SIZE = (
        form_data.chunk.chunk_size if form_data.chunk is not None else cfg.CHUNK_SIZE
    )

    cfg.CHUNK_OVERLAP = (
        form_data.chunk.chunk_overlap
        if form_data.chunk is not None
        else cfg.CHUNK_OVERLAP
    )

    # Identity check, consistent with the other fields (was ``!= None``).
    cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
        form_data.web_loader_ssl_verification
        if form_data.web_loader_ssl_verification is not None
        else cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
    )

    cfg.YOUTUBE_LOADER_LANGUAGE = (
        form_data.youtube.language
        if form_data.youtube is not None
        else cfg.YOUTUBE_LOADER_LANGUAGE
    )

    # The translation setting lives on app.state, not on app.state.config.
    app.state.YOUTUBE_LOADER_TRANSLATION = (
        form_data.youtube.translation
        if form_data.youtube is not None
        else app.state.YOUTUBE_LOADER_TRANSLATION
    )

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "web_loader_ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
    }
411
412


Timothy J. Baek's avatar
Timothy J. Baek committed
413
414
415
416
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the current RAG prompt template.

    Unlike the other settings endpoints, this one depends on
    ``get_current_user`` rather than ``get_admin_user``.
    """
    template = app.state.config.RAG_TEMPLATE
    return {
        "status": True,
        "template": template,
    }


421
422
423
424
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the template, top-k, relevance threshold and hybrid-search flag.

    Access is guarded by the ``get_admin_user`` dependency.
    """
    cfg = app.state.config
    return {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
430
431


432
433
class QuerySettingsForm(BaseModel):
    """Request body for POST /query/settings/update."""

    # Stored as TOP_K.
    k: Optional[int] = None
    # Stored as RELEVANCE_THRESHOLD.
    r: Optional[float] = None
    # Stored as RAG_TEMPLATE.
    template: Optional[str] = None
    # Stored as ENABLE_RAG_HYBRID_SEARCH.
    hybrid: Optional[bool] = None
437
438
439
440
441
442


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the query settings with the submitted values.

    This is not a partial update: any field that is missing or falsy is reset
    to a fixed fallback (the imported ``RAG_TEMPLATE``, ``4``, ``0.0`` and
    ``False`` respectively) rather than keeping the current setting.

    Returns:
        The resulting query settings.
    """
    cfg = app.state.config

    cfg.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    cfg.TOP_K = form_data.k or 4
    cfg.RELEVANCE_THRESHOLD = form_data.r or 0.0
    cfg.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return {
        "status": True,
        "template": cfg.RAG_TEMPLATE,
        "k": cfg.TOP_K,
        "r": cfg.RELEVANCE_THRESHOLD,
        "hybrid": cfg.ENABLE_RAG_HYBRID_SEARCH,
    }
458
459


460
class QueryDocForm(BaseModel):
    """Request body for POST /query/doc."""

    collection_name: str
    query: str
    # Per-request overrides; the handler falls back to the configured
    # TOP_K / RELEVANCE_THRESHOLD when these are not supplied.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): query_doc_handler does not read this field; it consults
    # app.state.config.ENABLE_RAG_HYBRID_SEARCH instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
466
467


468
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single collection, with or without hybrid search.

    Hybrid search is selected by ``app.state.config.ENABLE_RAG_HYBRID_SEARCH``.
    ``k`` falls back to the configured ``TOP_K`` when missing or zero; ``r``
    falls back to ``RELEVANCE_THRESHOLD`` only when it is not supplied.

    Raises:
        HTTPException: 400 if the query fails for any reason.
    """
    try:
        top_k = form_data.k if form_data.k else app.state.config.TOP_K

        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            # ``is not None`` rather than truthiness: an explicit threshold of
            # 0.0 is a valid request and must not be replaced by the default.
            threshold = (
                form_data.r
                if form_data.r is not None
                else app.state.config.RELEVANCE_THRESHOLD
            )
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
                reranking_function=app.state.sentence_transformer_rf,
                r=threshold,
            )

        return query_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
498
499


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
500
501
502
class QueryCollectionsForm(BaseModel):
    """Request body for POST /query/collection."""

    collection_names: List[str]
    query: str
    # Per-request overrides; the handler falls back to the configured
    # TOP_K / RELEVANCE_THRESHOLD when these are not supplied.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): query_collection_handler does not read this field; it
    # consults app.state.config.ENABLE_RAG_HYBRID_SEARCH instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
506
507


508
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections, with or without hybrid search.

    Hybrid search is selected by ``app.state.config.ENABLE_RAG_HYBRID_SEARCH``.
    ``k`` falls back to the configured ``TOP_K`` when missing or zero; ``r``
    falls back to ``RELEVANCE_THRESHOLD`` only when it is not supplied.

    Raises:
        HTTPException: 400 if the query fails for any reason.
    """
    try:
        top_k = form_data.k if form_data.k else app.state.config.TOP_K

        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            # ``is not None`` rather than truthiness: an explicit threshold of
            # 0.0 is a valid request and must not be replaced by the default.
            threshold = (
                form_data.r
                if form_data.r is not None
                else app.state.config.RELEVANCE_THRESHOLD
            )
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
                reranking_function=app.state.sentence_transformer_rf,
                r=threshold,
            )

        return query_collection(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
539
540


Timothy J. Baek's avatar
Timothy J. Baek committed
541
542
543
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video through YoutubeLoader and store it in a collection.

    An existing collection with the same name is overwritten. When no
    collection name is given (``None`` or empty), the name is the value of
    ``calculate_sha256_string(url)`` truncated to 63 characters.

    Raises:
        HTTPException: 400 if loading or storing fails.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # The field is Optional, so treat None like the empty string instead
        # of passing it on as a collection name.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


570
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and store its content in a collection.

    Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    An existing collection with the same name is overwritten. When no
    collection name is given (``None`` or empty), the name is the value of
    ``calculate_sha256_string(url)`` truncated to 63 characters.

    Raises:
        HTTPException: 400 if validation, loading or storing fails.
    """
    try:
        loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        # The field is Optional, so treat None like the empty string instead
        # of passing it on as a collection name.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

597

598
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Build a WebBaseLoader for one URL or a sequence of URLs.

    Args:
        url: A single URL or a sequence of URLs to load.
        verify_ssl: Passed through to ``WebBaseLoader``.

    Raises:
        ValueError: If ``validate_url`` rejects the input, either by raising
            itself or by returning a falsy result.
    """
    url_is_valid = validate_url(url)
    if not url_is_valid:
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    loader_options = {
        "verify_ssl": verify_ssl,
        "requests_per_second": RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
        "continue_on_failure": True,
    }
    return WebBaseLoader(url, **loader_options)
608
609


610
611
612
613
def validate_url(url: Union[str, Sequence[str]]):
    """Validate a URL, or every URL in a sequence.

    Returns:
        ``True`` for a valid string URL, the conjunction of the per-item
        results for a sequence, and ``False`` for any other input type.

    Raises:
        ValueError: If a string URL is malformed, or if local web fetch is
            disabled and the host resolves to a private IPv4/IPv6 address.
    """
    if not isinstance(url, str):
        if isinstance(url, Sequence):
            return all(validate_url(item) for item in url)
        return False

    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    if ENABLE_RAG_LOCAL_WEB_FETCH:
        return True

    # Local web fetch is disabled: reject URLs whose host resolves to a
    # private address.
    hostname = urllib.parse.urlparse(url).hostname
    ipv4_addresses, ipv6_addresses = resolve_hostname(hostname)

    # This is technically still vulnerable to DNS rebinding attacks, as we
    # don't control WebBaseLoader.
    resolves_to_private = any(
        validators.ipv4(ip, private=True) for ip in ipv4_addresses
    ) or any(validators.ipv6(ip, private=True) for ip in ipv6_addresses)
    if resolves_to_private:
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    return True


634
635
636
637
638
639
640
641
642
643
644
def resolve_hostname(hostname):
    """Resolve ``hostname`` and split the results by address family.

    Returns:
        A ``(ipv4_addresses, ipv6_addresses)`` tuple of lists, one entry per
        ``socket.getaddrinfo`` result (duplicates are not removed).
    """
    ipv4_addresses, ipv6_addresses = [], []

    for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
        # The first element of the socket address is the IP string.
        if family == socket.AF_INET:
            ipv4_addresses.append(sockaddr[0])
        elif family == socket.AF_INET6:
            ipv6_addresses.append(sockaddr[0])

    return ipv4_addresses, ipv6_addresses


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
645
646
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Run a web search, load every result page and store it in a collection.

    An existing collection with the same name is overwritten. When no
    collection name is given (``None`` or empty), the name is the value of
    ``calculate_sha256_string(query)`` truncated to 63 characters.

    Raises:
        HTTPException: 400 with ``WEB_SEARCH_ERROR`` if the search itself
            fails, or 400 with the default error detail for any later failure.
    """
    try:
        try:
            web_results = search_web(form_data.query)
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
            )

        urls = [result.link for result in web_results]
        # NOTE(review): unlike store_web, the configured SSL-verification
        # setting is not forwarded here, so verify_ssl stays at its default.
        loader = get_web_loader(urls)
        data = loader.load()

        collection_name = form_data.collection_name
        # The field is Optional, so treat None like the empty string instead
        # of passing it on as a collection name.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except HTTPException:
        # Keep the specific search error intact; the generic handler below
        # would otherwise catch it and replace its detail.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


678
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: Documents as produced by a loader's ``load()``.
        collection_name: Name of the target collection.
        overwrite: Passed through to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.

    Raises:
        ValueError: If splitting yields no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    # Return the bare bool promised by the annotation. The previous
    # ``(result, None)`` tuple was always truthy, hiding storage failures from
    # any caller that tests the return value.
    return store_docs_in_vector_db(docs, collection_name, overwrite)
693
694
695


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw text into chunks and store them in a collection.

    Args:
        text: The text to split.
        metadata: Metadata attached to the documents created from ``text``.
        collection_name: Name of the target collection.
        overwrite: Passed through to ``store_docs_in_vector_db``.

    Returns:
        The result of ``store_docs_in_vector_db``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
707
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a new Chroma collection.

    Args:
        docs: Chunks exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection of that name is deleted
            before the new one is created.

    Returns:
        ``True`` on success, or when the failure is an exception whose class
        is named ``UniqueConstraintError``; ``False`` for any other error
        (the exception is logged, not re-raised).
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts, metadatas = [], []
    for doc in docs:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if existing.name == collection_name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        # Built from the current settings on every call.
        embedding_func = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        # Newlines are flattened for embedding only; the stored documents
        # keep their original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # Matched by class name, so the exception type need not be imported.
        return type(e).__name__ == "UniqueConstraintError"


751
752
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a document loader for ``file_path``.

    The file extension is checked first; the MIME type is consulted for EPUB,
    Word, Excel and PowerPoint files and for generic ``text/`` content.

    Args:
        filename: Name of the file; only its extension is used.
        file_content_type: MIME type of the file; may be ``None`` or empty.
        file_path: Location handed to the selected loader.

    Returns:
        ``(loader, known_type)``. ``known_type`` is ``False`` only when no
        rule matched and the plain-text loader was used as a fallback.
    """
    # Text after the last "."; a name without any dot is used whole, which is
    # how a bare "Dockerfile" matches the "dockerfile" entry below.
    extension = filename.rsplit(".", 1)[-1].lower()

    source_extensions = (
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    )
    word_types = (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )
    excel_types = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
    powerpoint_types = (
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )

    known_type = True

    # The order of these checks matters: the first matching rule wins.
    if extension == "pdf":
        loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
    elif extension == "csv":
        loader = CSVLoader(file_path)
    elif extension == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif extension == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif extension in ("htm", "html"):
        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
    elif extension == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif file_content_type == "application/epub+zip":
        loader = UnstructuredEPubLoader(file_path)
    elif file_content_type in word_types or extension in ("doc", "docx"):
        loader = Docx2txtLoader(file_path)
    elif file_content_type in excel_types or extension in ("xls", "xlsx"):
        loader = UnstructuredExcelLoader(file_path)
    elif file_content_type in powerpoint_types or extension in ("ppt", "pptx"):
        loader = UnstructuredPowerPointLoader(file_path)
    else:
        # Source files and "text/" content are known plain text; anything else
        # is still loaded as text but reported as an unknown type.
        known_type = extension in source_extensions or bool(
            file_content_type and "text/" in file_content_type
        )
        loader = TextLoader(file_path, autodetect_encoding=True)

    return loader, known_type


843
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under ``UPLOAD_DIR`` and store it in the vector DB.

    Args:
        collection_name: Target collection. When omitted, the first 63
            characters of the saved file's SHA-256 digest are used.
        file: The uploaded file.
        user: Result of the ``get_current_user`` dependency (not otherwise
            used here).

    Returns:
        A dict with ``status``, ``collection_name``, ``filename`` and
        ``known_type`` when the store step reports success. If the store step
        reports failure without raising, nothing is returned (``None``).

    Raises:
        HTTPException: 400 for any error while saving, loading or storing the
            file, with a dedicated message when pandoc is missing.
    """
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    log.info(f"file.content_type: {file.content_type}")
    try:
        # Keep only the final path component of the client-supplied name.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        # Hash the saved file only when a name has to be derived, and let the
        # context manager close the handle even if hashing fails.
        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)

            if result:
                return {
                    "status": True,
                    "collection_name": collection_name,
                    "filename": filename,
                    "known_type": known_type,
                }
        except Exception as e:
            # NOTE: this HTTPException is itself caught by the outer handler
            # below, so the client receives the 400 response built there.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=str(e),
            )
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
898
899


900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
class TextRAGForm(BaseModel):
    """Request body accepted by the ``POST /text`` endpoint."""

    # Stored as the "name" entry of the text's metadata.
    name: str
    # Raw text to be stored in the vector DB.
    content: str
    # Target collection; when omitted the endpoint derives a name from a
    # hash of ``content``.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Store a piece of raw text in the vector DB.

    Args:
        form_data: Text, its display name and an optional collection name.
        user: Result of the ``get_current_user`` dependency; its ``id`` is
            recorded as ``created_by`` in the metadata.

    Returns:
        ``{"status": True, "collection_name": ...}`` on success.

    Raises:
        HTTPException: 500 when the store step reports failure.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncate to 63 characters, as the file-based endpoints in this
        # module do for their hash-derived names, so the generated name
        # stays within Chroma's collection-name length limit.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


931
932
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Store every file found under ``DOCS_DIR`` and register new documents.

    Files whose own name starts with "." are skipped. Each remaining file is
    loaded, stored in the vector DB under a collection named after the first
    63 characters of its SHA-256 digest, and, if no document with the same
    sanitized name exists yet, inserted into ``Documents`` with the tags
    returned by ``extract_folders_after_data_docs``.

    Args:
        user: Result of the ``get_admin_user`` dependency; its ``id`` is
            passed to ``Documents.insert_new_doc``.

    Returns:
        Always ``True``. A failure on one file is logged and does not stop
        the scan.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            content_type, _ = mimetypes.guess_type(path)

            # The context manager closes the handle even if hashing fails.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, _ = get_loader(filename, content_type, str(path))
            data = loader.load()

            if not store_data_in_vector_db(data, collection_name):
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is not None:
                # Already registered; leave the existing record untouched.
                continue

            if tags:
                content = json.dumps({"tags": [{"name": name} for name in tags]})
            else:
                content = "{}"

            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    **{
                        "name": sanitized_filename,
                        "title": filename,
                        "collection_name": collection_name,
                        "filename": filename,
                        "content": content,
                    }
                ),
            )
        except Exception as e:
            # Best effort: log the failure and move on to the next file.
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
992
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Call ``CHROMA_CLIENT.reset()``; any error it raises propagates.

    The ``user`` argument only enforces the ``get_admin_user`` dependency.
    Nothing is returned.
    """
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
995
996
997


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Empty ``UPLOAD_DIR`` and reset the Chroma client.

    Files and symlinks directly inside the upload directory are unlinked and
    sub-directories are removed recursively. A failure to delete an entry, or
    a failure of the client reset, is logged and otherwise ignored.

    Args:
        user: Result of the ``get_admin_user`` dependency (not otherwise
            used here).

    Returns:
        Always ``True``.
    """
    upload_root = str(UPLOAD_DIR)
    for entry in os.listdir(upload_root):
        target = os.path.join(upload_root, entry)
        try:
            # A symlink is unlinked even when it points at a directory, so
            # the directory it points to is never walked.
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target)
            elif os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s", target, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026


if ENV == "dev":
    # These routes are registered only when ENV equals "dev".

    @app.get("/ef")
    async def get_embeddings():
        """Return the result of ``app.state.EMBEDDING_FUNCTION`` for a fixed text."""
        result = app.state.EMBEDDING_FUNCTION("hello world")
        return {"result": result}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        """Return the result of ``app.state.EMBEDDING_FUNCTION`` for ``text``."""
        result = app.state.EMBEDDING_FUNCTION(text)
        return {"result": result}