main.py 29.1 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
    UnstructuredPowerPointLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
    YoutubeLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
33
)
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

36
37
38
39
40
import validators
import urllib.parse
import socket


41
42
from pydantic import BaseModel
from typing import Optional
43
import mimetypes
44
import uuid
45
46
import json

47
import sentence_transformers
48

49
50
51
52
53
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
54

55
from apps.rag.utils import (
56
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
57
58
59
60
61
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
62
)
Timothy J. Baek's avatar
Timothy J. Baek committed
63

64
65
66
67
68
69
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
70
from utils.utils import get_current_user, get_admin_user
71

72
from config import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
73
    ENV,
74
    SRC_LOG_LEVELS,
75
76
    UPLOAD_DIR,
    DOCS_DIR,
77
78
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
79
    RAG_EMBEDDING_ENGINE,
80
    RAG_EMBEDDING_MODEL,
81
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
82
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
83
    ENABLE_RAG_HYBRID_SEARCH,
84
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
Steven Kreitzer's avatar
Steven Kreitzer committed
85
    RAG_RERANKING_MODEL,
86
    PDF_EXTRACT_IMAGES,
87
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
88
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
89
90
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
91
    DEVICE_TYPE,
92
93
94
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
95
    RAG_TEMPLATE,
96
    ENABLE_RAG_LOCAL_WEB_FETCH,
97
    YOUTUBE_LOADER_LANGUAGE,
98
    AppConfig,
99
)
100

101
102
from constants import ERROR_MESSAGES

103
104
105
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()
app.state.config = AppConfig()


def _seed_app_config(target) -> None:
    """Copy the startup RAG settings imported from ``config`` onto *target*.

    Attributes are assigned one at a time, in the order listed, with
    ``setattr`` -- the same as writing ``target.NAME = value`` for each one.
    The admin endpoints in this module overwrite these values at runtime.
    """
    startup_values = {
        # Retrieval.
        "TOP_K": RAG_TOP_K,
        "RELEVANCE_THRESHOLD": RAG_RELEVANCE_THRESHOLD,
        "ENABLE_RAG_HYBRID_SEARCH": ENABLE_RAG_HYBRID_SEARCH,
        "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        # Text splitting.
        "CHUNK_SIZE": CHUNK_SIZE,
        "CHUNK_OVERLAP": CHUNK_OVERLAP,
        # Models and prompt template.
        "RAG_EMBEDDING_ENGINE": RAG_EMBEDDING_ENGINE,
        "RAG_EMBEDDING_MODEL": RAG_EMBEDDING_MODEL,
        "RAG_RERANKING_MODEL": RAG_RERANKING_MODEL,
        "RAG_TEMPLATE": RAG_TEMPLATE,
        # OpenAI-compatible embedding backend.
        "OPENAI_API_BASE_URL": RAG_OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": RAG_OPENAI_API_KEY,
        # Document loaders.
        "PDF_EXTRACT_IMAGES": PDF_EXTRACT_IMAGES,
        "YOUTUBE_LOADER_LANGUAGE": YOUTUBE_LOADER_LANGUAGE,
    }
    for attr_name, value in startup_values.items():
        setattr(target, attr_name, value)


_seed_app_config(app.state.config)

# Kept directly on app.state (not on app.state.config); starts as None and is
# changed through POST /config/update.
app.state.YOUTUBE_LOADER_TRANSLATION = None


137
138
139
140
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """(Re)load the SentenceTransformer embedding model into ``app.state``.

    A SentenceTransformer is built only when *embedding_model* is non-empty AND
    the configured ``RAG_EMBEDDING_ENGINE`` is the empty string (which appears
    to select the built-in backend -- confirm against config). In every other
    case ``app.state.sentence_transformer_ef`` is reset to ``None``.

    Args:
        embedding_model: model name/identifier handed to ``get_model_path``.
        update_model: forwarded to ``get_model_path``.
    """
    if not embedding_model or app.state.config.RAG_EMBEDDING_ENGINE != "":
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """(Re)load the CrossEncoder passed as ``reranking_function`` to hybrid search.

    Stores the model on ``app.state.sentence_transformer_rf``, or ``None`` when
    *reranking_model* is empty.

    Args:
        reranking_model: model name/identifier handed to ``get_model_path``.
        update_model: forwarded to ``get_model_path``.
    """
    app.state.sentence_transformer_rf = (
        sentence_transformers.CrossEncoder(
            get_model_path(reranking_model, update_model),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
        )
        if reranking_model
        else None
    )


# Load the configured embedding / reranking models once at import time,
# honouring the *_AUTO_UPDATE settings.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Embedding callable used by the query endpoints; rebuilt by /embedding/update.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
)

# Every origin is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended for this sub-application.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
196
# Base request body for endpoints that write into a named vector-DB collection.
class CollectionNameForm(BaseModel):
    # Defaults to "test"; the URL ingestion handlers derive a name when it is "".
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
200
# Request body for URL-based ingestion (POST /youtube and POST /web).
class UrlForm(CollectionNameForm):
    url: str  # page or video URL to load

Timothy J. Baek's avatar
Timothy J. Baek committed
203

Timothy J. Baek's avatar
Timothy J. Baek committed
204
205
@app.get("/")
async def get_status():
    """Report that the RAG app is up, together with its chunking and model settings."""
    cfg = app.state.config
    return dict(
        status=True,
        chunk_size=cfg.CHUNK_SIZE,
        chunk_overlap=cfg.CHUNK_OVERLAP,
        template=cfg.RAG_TEMPLATE,
        embedding_engine=cfg.RAG_EMBEDDING_ENGINE,
        embedding_model=cfg.RAG_EMBEDDING_MODEL,
        reranking_model=cfg.RAG_RERANKING_MODEL,
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
217
218
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Admin-only: current embedding engine/model and OpenAI connection settings.

    Note that the OpenAI API key is returned as-is in the response.
    """
    cfg = app.state.config
    openai_settings = {
        "url": cfg.OPENAI_API_BASE_URL,
        "key": cfg.OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": cfg.RAG_EMBEDDING_ENGINE,
        "embedding_model": cfg.RAG_EMBEDDING_MODEL,
        "openai_config": openai_settings,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
230
231
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Admin-only: name of the configured reranking model."""
    # The misspelled function name is kept so existing references keep working.
    return dict(status=True, reranking_model=app.state.config.RAG_RERANKING_MODEL)
Steven Kreitzer's avatar
Steven Kreitzer committed
236
237


238
239
240
241
242
# Connection settings for the OpenAI-compatible embedding backend.
class OpenAIConfigForm(BaseModel):
    url: str  # stored as OPENAI_API_BASE_URL
    key: str  # stored as OPENAI_API_KEY


243
# Request body for POST /embedding/update.
class EmbeddingModelUpdateForm(BaseModel):
    # Applied only when embedding_engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
249
250
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Admin-only: switch the embedding engine/model and rebuild the embedding function.

    For the "ollama" and "openai" engines the OpenAI connection settings are
    also replaced when ``form_data.openai_config`` is supplied.

    Returns:
        The resulting embedding configuration (same shape as ``GET /embedding``).

    Raises:
        HTTPException: 500 if reloading the model or rebuilding the embedding
            function fails.
    """
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if (
            app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
            and form_data.openai_config is not None
        ):
            app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
            app.state.config.OPENAI_API_KEY = form_data.openai_config.key

        # NOTE(review): the config values above are already written when the
        # reload below runs. If it raises, the stored config names the new
        # model while the previously loaded model/function may stay active.
        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        return {
            "status": True,
            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.config.OPENAI_API_BASE_URL,
                "key": app.state.config.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
290
291


Steven Kreitzer's avatar
Steven Kreitzer committed
292
293
# Request body for POST /reranking/update.
class RerankingModelUpdateForm(BaseModel):
    reranking_model: str  # an empty string unloads the reranker
294

Steven Kreitzer's avatar
Steven Kreitzer committed
295
296
297
298
299
300

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Admin-only: switch the reranking model and reload it.

    The model is reloaded with ``update_model=True``, which is forwarded to
    ``get_model_path``. An empty ``reranking_model`` unloads the reranker.

    Returns:
        ``{"status": True, "reranking_model": <name>}``.

    Raises:
        HTTPException: 500 if loading the model fails.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # `True` must be inside the call. Written as `update_reranking_model(...), True`
        # it only built a discarded tuple, and update_model silently stayed False.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
320
321
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Admin-only: document-processing settings.

    Covers PDF image extraction, chunking, web-loader SSL verification and the
    YouTube loader language/translation.
    """
    cfg = app.state.config
    chunk_settings = dict(
        chunk_size=cfg.CHUNK_SIZE,
        chunk_overlap=cfg.CHUNK_OVERLAP,
    )
    youtube_settings = dict(
        language=cfg.YOUTUBE_LOADER_LANGUAGE,
        translation=app.state.YOUTUBE_LOADER_TRANSLATION,
    )
    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
        "web_loader_ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": youtube_settings,
    }


# Chunking parameters for POST /config/update (copied to CHUNK_SIZE / CHUNK_OVERLAP).
class ChunkParamUpdateForm(BaseModel):
    chunk_size: int
    chunk_overlap: int


342
343
344
345
346
# YouTube loader settings for POST /config/update.
class YoutubeLoaderConfig(BaseModel):
    language: List[str]  # stored as YOUTUBE_LOADER_LANGUAGE, passed to YoutubeLoader
    translation: Optional[str] = None  # stored on app.state.YOUTUBE_LOADER_TRANSLATION


Timothy J. Baek's avatar
Timothy J. Baek committed
347
# Request body for POST /config/update; any section left as None keeps its
# current value.
class ConfigUpdateForm(BaseModel):
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    web_loader_ssl_verification: Optional[bool] = None
    youtube: Optional[YoutubeLoaderConfig] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
352
353
354
355


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Admin-only: partially update the document-processing settings.

    Each section of *form_data* is optional. A section left as ``None`` keeps
    its current value, which is written back unchanged.

    Returns:
        The full resulting configuration, in the same shape as ``GET /config``.
    """
    cfg = app.state.config

    cfg.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else cfg.PDF_EXTRACT_IMAGES
    )

    chunk = form_data.chunk
    cfg.CHUNK_SIZE = chunk.chunk_size if chunk is not None else cfg.CHUNK_SIZE
    cfg.CHUNK_OVERLAP = chunk.chunk_overlap if chunk is not None else cfg.CHUNK_OVERLAP

    # Compare with `is not None` so an explicit False is still applied.
    cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
        form_data.web_loader_ssl_verification
        if form_data.web_loader_ssl_verification is not None
        else cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
    )

    youtube = form_data.youtube
    cfg.YOUTUBE_LOADER_LANGUAGE = (
        youtube.language if youtube is not None else cfg.YOUTUBE_LOADER_LANGUAGE
    )
    # The translation setting lives on app.state, not on app.state.config.
    app.state.YOUTUBE_LOADER_TRANSLATION = (
        youtube.translation
        if youtube is not None
        else app.state.YOUTUBE_LOADER_TRANSLATION
    )

    return {
        "status": True,
        "pdf_extract_images": cfg.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": cfg.CHUNK_SIZE,
            "chunk_overlap": cfg.CHUNK_OVERLAP,
        },
        "web_loader_ssl_verification": cfg.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": {
            "language": cfg.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
    }
405
406


Timothy J. Baek's avatar
Timothy J. Baek committed
407
408
409
410
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template (guarded by get_current_user, not admin-only)."""
    current_template = app.state.config.RAG_TEMPLATE
    return {"status": True, "template": current_template}


415
416
417
418
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Admin-only: prompt template, top-k, relevance threshold and hybrid-search flag."""
    cfg = app.state.config
    return dict(
        status=True,
        template=cfg.RAG_TEMPLATE,
        k=cfg.TOP_K,
        r=cfg.RELEVANCE_THRESHOLD,
        hybrid=cfg.ENABLE_RAG_HYBRID_SEARCH,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
424
425


426
427
# Request body for POST /query/settings/update; every field is optional.
class QuerySettingsForm(BaseModel):
    k: Optional[int] = None  # stored as TOP_K
    r: Optional[float] = None  # stored as RELEVANCE_THRESHOLD
    template: Optional[str] = None  # stored as RAG_TEMPLATE
    hybrid: Optional[bool] = None  # stored as ENABLE_RAG_HYBRID_SEARCH
431
432
433
434
435
436


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Admin-only: replace the retrieval settings.

    A missing or falsy field is reset to a fixed fallback instead of keeping
    its current value:

    - template -> the module-level ``RAG_TEMPLATE``
    - k -> 4
    - r -> 0.0
    - hybrid -> False

    NOTE(review): the k fallback of 4 differs from the startup value
    ``RAG_TOP_K`` -- confirm this is intended.
    """
    cfg = app.state.config
    cfg.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    cfg.TOP_K = form_data.k or 4
    cfg.RELEVANCE_THRESHOLD = form_data.r or 0.0
    cfg.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return dict(
        status=True,
        template=cfg.RAG_TEMPLATE,
        k=cfg.TOP_K,
        r=cfg.RELEVANCE_THRESHOLD,
        hybrid=cfg.ENABLE_RAG_HYBRID_SEARCH,
    )
452
453


454
# Request body for POST /query/doc.
class QueryDocForm(BaseModel):
    collection_name: str
    query: str
    k: Optional[int] = None  # falsy -> configured TOP_K
    r: Optional[float] = None  # falsy -> configured RELEVANCE_THRESHOLD (hybrid only)
    # NOTE(review): not read by query_doc_handler in this file, which uses the
    # global ENABLE_RAG_HYBRID_SEARCH setting instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
460
461


462
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Retrieve the chunks of one collection that best match ``form_data.query``.

    When the global ``ENABLE_RAG_HYBRID_SEARCH`` setting is on, hybrid search
    with the reranking model is used; otherwise plain vector search. A falsy
    ``k`` / ``r`` falls back to the configured TOP_K / RELEVANCE_THRESHOLD.

    Raises:
        HTTPException: 400 wrapping any retrieval error.
    """
    try:
        use_hybrid = app.state.config.ENABLE_RAG_HYBRID_SEARCH
        top_k = form_data.k or app.state.config.TOP_K

        if not use_hybrid:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        return query_doc_with_hybrid_search(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or app.state.config.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
492
493


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
494
495
496
# Request body for POST /query/collection.
class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str
    k: Optional[int] = None  # falsy -> configured TOP_K
    r: Optional[float] = None  # falsy -> configured RELEVANCE_THRESHOLD (hybrid only)
    # NOTE(review): not read by query_collection_handler in this file, which
    # uses the global ENABLE_RAG_HYBRID_SEARCH setting instead.
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
500
501


502
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Retrieve the chunks across several collections that best match the query.

    When the global ``ENABLE_RAG_HYBRID_SEARCH`` setting is on, hybrid search
    with the reranking model is used; otherwise plain vector search. A falsy
    ``k`` / ``r`` falls back to the configured TOP_K / RELEVANCE_THRESHOLD.

    Raises:
        HTTPException: 400 wrapping any retrieval error.
    """
    try:
        use_hybrid = app.state.config.ENABLE_RAG_HYBRID_SEARCH
        top_k = form_data.k or app.state.config.TOP_K

        if not use_hybrid:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

        return query_collection_with_hybrid_search(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=top_k,
            reranking_function=app.state.sentence_transformer_rf,
            r=form_data.r or app.state.config.RELEVANCE_THRESHOLD,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
533
534


Timothy J. Baek's avatar
Timothy J. Baek committed
535
536
537
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video (with video info) and index it into the vector DB.

    The loader uses the configured YouTube language/translation settings. The
    target collection is ``form_data.collection_name``. When that is empty or
    ``None``, the first 63 characters of ``calculate_sha256_string(url)`` are
    used instead. An existing collection of that name is overwritten.

    Returns:
        ``{"status": True, "collection_name": ..., "filename": <url>}``.

    Raises:
        HTTPException: 400 wrapping any loading or storage error.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        # collection_name is Optional on the form: None must fall back to the
        # derived name exactly like "" does, instead of reaching the vector DB.
        collection_name = form_data.collection_name
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        # NOTE(review): the return value of store_data_in_vector_db is not checked here.
        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


564
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and index its content into the vector DB.

    The URL is validated by ``get_web_loader``, which uses the configured SSL
    verification setting. The target collection is
    ``form_data.collection_name``. When that is empty or ``None``, the first
    63 characters of ``calculate_sha256_string(url)`` are used instead. An
    existing collection of that name is overwritten.

    Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    Returns:
        ``{"status": True, "collection_name": ..., "filename": <url>}``.

    Raises:
        HTTPException: 400 wrapping any validation, loading or storage error.
    """
    try:
        loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        # collection_name is Optional on the form: None must fall back to the
        # derived name exactly like "" does, instead of reaching the vector DB.
        collection_name = form_data.collection_name
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        # NOTE(review): the return value of store_data_in_vector_db is not checked here.
        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


592
def get_web_loader(url: str, verify_ssl: bool = True):
    """Validate *url* and return a ``WebBaseLoader`` for it.

    When ``ENABLE_RAG_LOCAL_WEB_FETCH`` is off, the URL's host is resolved. The
    request is refused if any resolved address is flagged as private by
    ``validators`` or as non-public by ``is_non_public_ip``.

    Args:
        url: user-supplied URL (untrusted).
        verify_ssl: forwarded to ``WebBaseLoader``.

    Raises:
        ValueError: ``ERROR_MESSAGES.INVALID_URL`` for a malformed or refused URL.
    """
    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    if not ENABLE_RAG_LOCAL_WEB_FETCH:
        # Local web fetch is disabled: refuse URLs whose host resolves to a
        # private/local address (SSRF protection).
        parsed_url = urllib.parse.urlparse(url)
        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
        # This is technically still vulnerable to DNS rebinding, as we don't
        # control how WebBaseLoader resolves the host when it fetches.
        # The validators checks are kept. is_non_public_ip additionally covers
        # ranges such as 0.0.0.0/8, 100.64.0.0/10, reserved space,
        # IPv4-mapped IPv6 and scoped link-local addresses.
        # NOTE(review): depending on the installed validators version,
        # ipv6() may not accept `private=`, in which case that check never matches.
        for ip in ipv4_addresses:
            if validators.ipv4(ip, private=True) or is_non_public_ip(ip):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
        for ip in ipv6_addresses:
            if validators.ipv6(ip, private=True) or is_non_public_ip(ip):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)

    return WebBaseLoader(url, verify_ssl=verify_ssl)


def is_non_public_ip(ip: str) -> bool:
    """Return True if *ip* must not be fetched while local web fetch is disabled.

    Uses the stdlib ``ipaddress`` classification. Anything not globally
    routable (private, loopback, link-local, unspecified, reserved, shared
    address space), or multicast, counts as non-public.

    - IPv4-mapped IPv6 addresses are judged by their embedded IPv4 address.
    - A zone suffix such as ``%eth0`` is ignored.
    - Unparseable input is treated as non-public (fail closed).
    """
    import ipaddress

    try:
        address = ipaddress.ip_address(ip.split("%", 1)[0])
    except ValueError:
        return True
    mapped = getattr(address, "ipv4_mapped", None)
    if mapped is not None:
        address = mapped
    return not address.is_global or address.is_multicast
610
611
612
613
614
615
616
617
618
619
620
621
622


def resolve_hostname(hostname):
    """Resolve *hostname* and split the results by address family.

    Returns:
        An ``(ipv4_addresses, ipv6_addresses)`` tuple of address-string lists,
        each in ``socket.getaddrinfo`` order. Duplicates are not removed, and
        entries of other families are ignored.
    """
    ipv4_addresses = []
    ipv6_addresses = []
    for family, _socktype, _proto, _canonname, sockaddr in socket.getaddrinfo(
        hostname, None
    ):
        if family == socket.AF_INET:
            ipv4_addresses.append(sockaddr[0])
        elif family == socket.AF_INET6:
            ipv6_addresses.append(sockaddr[0])
    return ipv4_addresses, ipv6_addresses


623
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> tuple:
    """Split loaded documents into chunks and store them in *collection_name*.

    Splitting uses the current CHUNK_SIZE / CHUNK_OVERLAP settings with
    ``add_start_index=True``.

    Returns:
        A ``(stored, None)`` tuple, where ``stored`` is the boolean result of
        ``store_docs_in_vector_db``.

        NOTE(review): a non-empty tuple is always truthy, so callers that test
        ``if result:`` cannot detect a failed store; they would need
        ``result[0]``. The return shape is kept as-is for existing callers.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` if splitting yields no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    return store_docs_in_vector_db(docs, collection_name, overwrite), None
638
639
640


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw *text* string into chunks and store them in *collection_name*.

    *metadata* is passed to the splitter as the metadata for *text*. Chunking
    uses the current CHUNK_SIZE / CHUNK_OVERLAP settings. Returns whatever
    ``store_docs_in_vector_db`` returns.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
652
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed *docs* and add them to the Chroma collection *collection_name*.

    When *overwrite* is true, an existing collection with that name is deleted
    first. Embeddings come from an embedding function built for the current
    engine/model settings. They are computed on each document's text with
    newlines replaced by spaces, while the stored documents keep their
    original text. Each chunk gets a random UUID4 id.

    Returns:
        True on success. On any exception the error is logged, and the result
        is True only if the exception class is named ``UniqueConstraintError``
        (presumably raised by ``create_collection`` when the collection already
        exists -- TODO confirm); otherwise False.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if existing.name != collection_name:
                    continue
                log.info(f"deleting existing collection {collection_name}")
                CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embed = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )
        # Newlines are flattened only for embedding; `texts` is stored unchanged.
        embeddings = embed([text.replace("\n", " ") for text in texts])

        ids = [str(uuid.uuid4()) for _ in texts]
        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=ids,
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        return type(e).__name__ == "UniqueConstraintError"


696
697
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a LangChain document loader for an uploaded file.

    Selection uses the lower-cased text after the last "." of *filename* (the
    whole lower-cased name when it has no dot) and, for some formats, the MIME
    type in *file_content_type*. Checks run in a fixed priority order, and
    anything unrecognised falls back to a ``TextLoader``.

    Returns:
        ``(loader, known_type)``, where ``known_type`` is False only for that
        final fallback.
    """
    file_ext = filename.split(".")[-1].lower()

    # Extensions (or dot-less names such as "dockerfile") read as plain text.
    source_code_extensions = frozenset(
        (
            "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
            "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", "r",
            "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", "nginxconf",
            "conf", "m", "mm", "plsql", "perl", "rb", "rs", "db2", "scala",
            "bash", "swift", "vue", "svelte",
        )
    )
    word_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    excel_mimes = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
    powerpoint_mimes = (
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )

    if file_ext == "pdf":
        pdf_loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
        return pdf_loader, True
    if file_ext == "csv":
        return CSVLoader(file_path), True
    if file_ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if file_ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if file_ext in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if file_ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    # NOTE(review): legacy ".doc" files are routed to Docx2txtLoader as well;
    # confirm that loader can read them.
    if file_content_type == word_mime or file_ext in ("doc", "docx"):
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_mimes or file_ext in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True
    if file_content_type in powerpoint_mimes or file_ext in ("ppt", "pptx"):
        return UnstructuredPowerPointLoader(file_path), True

    # Everything else is read as text; it only counts as a known type when it
    # is a recognised source extension or has a "text/" MIME type.
    is_known_text = file_ext in source_code_extensions or bool(
        file_content_type and "text/" in file_content_type
    )
    return TextLoader(file_path, autodetect_encoding=True), is_known_text


788
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file, parse it, and index its contents in the vector DB.

    The upload is written to ``UPLOAD_DIR`` under its base name, parsed with the
    loader chosen by ``get_loader``, and stored via ``store_data_in_vector_db``.

    Args:
        collection_name: Target collection. When omitted, the first 63
            characters of ``calculate_sha256`` over the saved file are used.
        file: The uploaded document.
        user: The authenticated caller (required by the route; not used here).

    Returns:
        ``{"status", "collection_name", "filename", "known_type"}`` on success.

    Raises:
        HTTPException: 500 when ``store_data_in_vector_db`` raises or returns a
            falsy result; 400 for any other failure (with a dedicated message
            when pandoc is missing).
    """
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    log.info(f"file.content_type: {file.content_type}")
    try:
        # Drop any directory components supplied by the client so the upload
        # is always written directly inside UPLOAD_DIR.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        # Stream the upload to disk instead of reading it fully into memory.
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)
        except Exception as e:
            log.exception(e)
            # Use a serializable message; a raw exception object as ``detail``
            # cannot be rendered into the JSON error response.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(e),
            ) from e

        if not result:
            # Without this, a falsy result would return None with HTTP 200.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except HTTPException:
        # Already a deliberate HTTP error; don't rewrap it as a 400 below.
        raise
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
843
844


845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
class TextRAGForm(BaseModel):
    """Request body for the ``/text`` endpoint (``store_text``)."""

    # Display name; stored as ``name`` in the indexed text's metadata.
    name: str
    # The raw text to index.
    content: str
    # Target collection; when None, store_text derives one from a hash of content.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text snippet in the vector DB.

    Args:
        form_data: The text (``content``), a display ``name`` recorded in the
            metadata, and an optional target ``collection_name``.
        user: The authenticated caller; its id is recorded as ``created_by``.

    Returns:
        ``{"status": True, "collection_name": ...}`` on success.

    Raises:
        HTTPException: 500 when ``store_text_in_vector_db`` returns a falsy
            result.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncate to 63 characters, matching the hash-derived names built by
        # store_doc and scan_docs_dir (presumably the vector store's
        # collection-name length limit).
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


876
877
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under DOCS_DIR and register it as a document.

    For each file, the contents are hashed with ``calculate_sha256`` to derive
    the collection name. The file is parsed with ``get_loader`` and stored with
    ``store_data_in_vector_db``. When storing succeeds and no document with the
    sanitized filename exists yet, a document record is inserted. That record is
    tagged with the folder names from ``extract_folders_after_data_docs``.

    Any failure for a single file is logged and the scan continues with the next
    file.

    Returns:
        Always ``True``.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            # Skip directories and dotfiles (only the file's own name is checked).
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)[0]

            # Context manager so the handle is closed even if hashing fails.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(filename, file_content_type, str(path))
            data = loader.load()

            result = store_data_in_vector_db(data, collection_name)
            if not result:
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is not None:
                # Already registered; keep the existing record.
                continue

            content = (
                json.dumps({"tags": [{"name": name} for name in tags]})
                if len(tags)
                else "{}"
            )
            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    **{
                        "name": sanitized_filename,
                        "title": filename,
                        "collection_name": collection_name,
                        "filename": filename,
                        "content": content,
                    }
                ),
            )
        except Exception as e:
            # Best effort: one bad file must not abort the whole scan.
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
937
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Call ``CHROMA_CLIENT.reset()``; the route is gated by ``get_admin_user``.

    Nothing is returned explicitly, so the handler's result is ``None``.
    NOTE(review): this destructive operation is exposed via GET.
    """
    chroma = CHROMA_CLIENT
    chroma.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
940
941
942


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything inside UPLOAD_DIR, then call ``CHROMA_CLIENT.reset()``.

    Each top-level entry is handled independently. Files and symlinks are
    unlinked, and directories are removed recursively. A failure on one entry is
    logged without stopping the rest. A failure in ``CHROMA_CLIENT.reset()`` is
    logged rather than raised.

    Returns:
        Always ``True``.
    """
    upload_root = str(UPLOAD_DIR)

    for entry_name in os.listdir(upload_root):
        target = os.path.join(upload_root, entry_name)
        try:
            # Symlinks are unlinked themselves, never followed into.
            unlinkable = os.path.isfile(target) or os.path.islink(target)
            if unlinkable:
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            log.error("Failed to delete %s. Reason: %s", target, err)

    try:
        CHROMA_CLIENT.reset()
    except Exception as err:
        log.exception(err)

    return True
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
961
962
963
964
965
966
967
968
969
970
971


if ENV == "dev":

    def _embedding_result(text: str) -> dict:
        # Shared response shape for both debug endpoints below.
        return {"result": app.state.EMBEDDING_FUNCTION(text)}

    @app.get("/ef")
    async def get_embeddings():
        """Dev-only: return the embedding of the fixed string "hello world"."""
        return _embedding_result("hello world")

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        """Dev-only: return the embedding of the path parameter ``text``."""
        return _embedding_result(text)