main.py 22.8 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
)
32
33
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pydantic import BaseModel
from typing import Optional
36
import mimetypes
37
import uuid
38
39
import json

40
import sentence_transformers
41

42
43
44
45
46
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
47

48
from apps.rag.utils import (
49
    get_model_path,
50
    query_embeddings_doc,
Timothy J. Baek's avatar
Timothy J. Baek committed
51
    get_embeddings_function,
52
53
    query_embeddings_collection,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
54

55
56
57
58
59
60
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
61
from utils.utils import get_current_user, get_admin_user
62

63
from config import (
64
    SRC_LOG_LEVELS,
65
66
    UPLOAD_DIR,
    DOCS_DIR,
67
68
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
69
    RAG_EMBEDDING_ENGINE,
70
    RAG_EMBEDDING_MODEL,
71
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
72
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
73
    ENABLE_RAG_HYBRID_SEARCH,
Steven Kreitzer's avatar
Steven Kreitzer committed
74
    RAG_RERANKING_MODEL,
75
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
76
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
77
78
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
79
    DEVICE_TYPE,
80
81
82
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
83
    RAG_TEMPLATE,
84
)
85

86
87
from constants import ERROR_MESSAGES

88
89
90
# Module logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Everything below is mutable runtime configuration kept on app.state and
# seeded from the values imported from config.

# Retrieval settings.
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH

# Text splitting settings.
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking model selection and the prompt template.
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

# Credentials handed to get_embeddings_function for remote embedding engines.
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# Passed to PyPDFLoader as extract_images; off until changed via /config/update.
app.state.PDF_EXTRACT_IMAGES = False

Steven Kreitzer's avatar
Steven Kreitzer committed
111

112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load (or unload) the local sentence-transformers embedding model.

    A model is loaded only when ``embedding_model`` is non-empty and the
    configured embedding engine is the empty string; in every other case
    ``app.state.sentence_transformer_ef`` is cleared to ``None``.

    Args:
        embedding_model: Model identifier passed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` as its second argument.
    """
    if not embedding_model or app.state.RAG_EMBEDDING_ENGINE != "":
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load (or unload) the cross-encoder used for reranking.

    When ``reranking_model`` is empty, ``app.state.sentence_transformer_rf``
    is cleared to ``None``; otherwise a ``CrossEncoder`` is built for it.

    Args:
        reranking_model: Model identifier passed to ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` as its second argument.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the configured models once at import time so that
# app.state.sentence_transformer_ef / sentence_transformer_rf exist before
# any request handler reads them.
update_embedding_model(
    app.state.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)

update_reranking_model(
    app.state.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
149

Timothy J. Baek's avatar
Timothy J. Baek committed
150
151
# Origins accepted by the CORS middleware; "*" accepts every origin.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# very permissive policy - confirm this is intended for production.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
162
# Base request body for endpoints that target a named vector collection.
class CollectionNameForm(BaseModel):
    # Target collection; falls back to "test" when the client omits it.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
166
167
168
# Request body for POST /web: inherits collection_name and adds the page URL.
class StoreWebForm(CollectionNameForm):
    # Address handed to WebBaseLoader by store_web.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
169

Timothy J. Baek's avatar
Timothy J. Baek committed
170
171
@app.get("/")
async def get_status():
    """Return a liveness flag plus the current chunking, template and model settings."""
    state = app.state
    return {
        "status": True,
        "chunk_size": state.CHUNK_SIZE,
        "chunk_overlap": state.CHUNK_OVERLAP,
        "template": state.RAG_TEMPLATE,
        "embedding_engine": state.RAG_EMBEDDING_ENGINE,
        "embedding_model": state.RAG_EMBEDDING_MODEL,
        "reranking_model": state.RAG_RERANKING_MODEL,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
183
184
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the active embedding engine/model and the stored OpenAI settings (admin only)."""
    state = app.state
    openai_config = {
        "url": state.OPENAI_API_BASE_URL,
        "key": state.OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": state.RAG_EMBEDDING_ENGINE,
        "embedding_model": state.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
196
197
198
199
200
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the currently configured reranking model (admin only)."""
    return {
        "status": True,
        "reranking_model": app.state.RAG_RERANKING_MODEL,
    }


201
202
203
204
205
# Connection settings submitted alongside an embedding configuration update.
class OpenAIConfigForm(BaseModel):
    # Stored in app.state.OPENAI_API_BASE_URL by update_embedding_config.
    url: str
    # Stored in app.state.OPENAI_API_KEY by update_embedding_config.
    key: str


206
# Request body for POST /embedding/update.
class EmbeddingModelUpdateForm(BaseModel):
    # Optional API settings; only applied for the "ollama" / "openai" engines.
    openai_config: Optional[OpenAIConfigForm] = None
    # Engine selector; the empty string selects the local sentence-transformers model.
    embedding_engine: str
    # Identifier of the embedding model to switch to.
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
212
213
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and reload the local model (admin only).

    The new engine and model are written to ``app.state`` first because
    ``update_embedding_model`` reads the engine from there. If anything fails,
    the previous engine, model and OpenAI settings are restored so the
    configuration keeps matching the model object that is still loaded.

    Args:
        form_data: Requested engine, model and optional OpenAI settings.
        user: Admin user resolved by the ``get_admin_user`` dependency.

    Returns:
        A dict with ``status`` and the resulting embedding configuration.

    Raises:
        HTTPException: 500 when the configuration could not be applied.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )

    # Snapshot of everything this handler may overwrite, used for rollback.
    previous_engine = app.state.RAG_EMBEDDING_ENGINE
    previous_model = app.state.RAG_EMBEDDING_MODEL
    previous_url = app.state.OPENAI_API_BASE_URL
    previous_key = app.state.OPENAI_API_KEY

    try:
        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
            if form_data.openai_config is not None:
                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.OPENAI_API_KEY = form_data.openai_config.key

        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)

        return {
            "status": True,
            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.OPENAI_API_BASE_URL,
                "key": app.state.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        # update_embedding_model assigns the new model object only after it
        # has been constructed, so on failure the old object is still in
        # place; put the matching settings back next to it.
        app.state.RAG_EMBEDDING_ENGINE = previous_engine
        app.state.RAG_EMBEDDING_MODEL = previous_model
        app.state.OPENAI_API_BASE_URL = previous_url
        app.state.OPENAI_API_KEY = previous_key

        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
245
246


Steven Kreitzer's avatar
Steven Kreitzer committed
247
248
# Request body for POST /reranking/update.
class RerankingModelUpdateForm(BaseModel):
    # Identifier of the reranking model; an empty string unloads the reranker.
    reranking_model: str
249

Steven Kreitzer's avatar
Steven Kreitzer committed
250
251
252
253
254
255
256
257
258
259

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it (admin only).

    If loading the new model fails, the previously configured model name is
    restored so it keeps matching the reranker object that is still loaded.

    Args:
        form_data: Requested reranking model.
        user: Admin user resolved by the ``get_admin_user`` dependency.

    Returns:
        A dict with ``status`` and the resulting ``reranking_model``.

    Raises:
        HTTPException: 500 when the model could not be loaded.
    """
    log.info(
        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )

    previous_model = app.state.RAG_RERANKING_MODEL

    try:
        app.state.RAG_RERANKING_MODEL = form_data.reranking_model

        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        # update_reranking_model replaces the loaded object only after the new
        # one has been constructed, so the old reranker is still active here.
        app.state.RAG_RERANKING_MODEL = previous_model

        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
275
276
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF image-extraction flag and the chunking parameters (admin only)."""
    state = app.state
    chunk_settings = {
        "chunk_size": state.CHUNK_SIZE,
        "chunk_overlap": state.CHUNK_OVERLAP,
    }
    return {
        "status": True,
        "pdf_extract_images": state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
    }


# Chunking parameters submitted as part of ConfigUpdateForm.
class ChunkParamUpdateForm(BaseModel):
    # Written to app.state.CHUNK_SIZE (used as the text splitter's chunk_size).
    chunk_size: int
    # Written to app.state.CHUNK_OVERLAP (used as the text splitter's chunk_overlap).
    chunk_overlap: int


Timothy J. Baek's avatar
Timothy J. Baek committed
292
293
294
295
296
297
298
299
300
301
# Request body for POST /config/update.
class ConfigUpdateForm(BaseModel):
    # Written to app.state.PDF_EXTRACT_IMAGES (forwarded to PyPDFLoader).
    pdf_extract_images: bool
    # New chunk size / overlap values.
    chunk: ChunkParamUpdateForm


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Store new PDF/chunking settings on app.state and echo them back (admin only)."""
    state = app.state
    chunk = form_data.chunk

    state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
    state.CHUNK_SIZE = chunk.chunk_size
    state.CHUNK_OVERLAP = chunk.chunk_overlap

    chunk_settings = {
        "chunk_size": state.CHUNK_SIZE,
        "chunk_overlap": state.CHUNK_OVERLAP,
    }
    return {
        "status": True,
        "pdf_extract_images": state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
    }
311
312


Timothy J. Baek's avatar
Timothy J. Baek committed
313
314
315
316
317
318
319
320
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the active RAG prompt template to any authenticated user."""
    template = app.state.RAG_TEMPLATE
    return {"status": True, "template": template}


321
322
323
324
325
326
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the template, top-k, relevance threshold and hybrid flag (admin only)."""
    state = app.state
    return {
        "status": True,
        "template": state.RAG_TEMPLATE,
        "k": state.TOP_K,
        "r": state.RELEVANCE_THRESHOLD,
        "hybrid": state.ENABLE_RAG_HYBRID_SEARCH,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
330
331


332
333
# Request body for POST /query/settings/update; every field may be omitted.
class QuerySettingsForm(BaseModel):
    # Number of results to retrieve.
    k: Optional[int] = None
    # Relevance threshold.
    r: Optional[float] = None
    # Prompt template text.
    template: Optional[str] = None
    # Whether hybrid search is enabled.
    hybrid: Optional[bool] = None
337
338
339
340
341
342
343
344


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the query settings on app.state and echo them back (admin only).

    Any field that is missing or falsy is reset to its fallback: the
    configured RAG_TEMPLATE, 4 results, a 0.0 threshold and hybrid search off.
    """
    state = app.state

    state.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    state.TOP_K = form_data.k or 4
    state.RELEVANCE_THRESHOLD = form_data.r or 0.0
    state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return {
        "status": True,
        "template": state.RAG_TEMPLATE,
        "k": state.TOP_K,
        "r": state.RELEVANCE_THRESHOLD,
        "hybrid": state.ENABLE_RAG_HYBRID_SEARCH,
    }
354
355


356
# Request body for POST /query/doc.
class QueryDocForm(BaseModel):
    # Collection to search.
    collection_name: str
    # Query text.
    query: str
    # Optional per-request overrides of the app-wide query settings.
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
362
363


364
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single collection with the currently configured embedding setup.

    Per-request ``k``, ``r`` and ``hybrid`` override the app-wide settings.
    ``r`` and ``hybrid`` are compared against ``None`` so that an explicit
    ``0.0`` or ``False`` is honoured instead of silently falling back to the
    global value.

    Args:
        form_data: Collection name, query text and optional overrides.
        user: Authenticated user resolved by ``get_current_user``.

    Returns:
        Whatever ``query_embeddings_doc`` returns.

    Raises:
        HTTPException: 400 when the query fails for any reason.
    """
    try:
        embeddings_function = get_embeddings_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        return query_embeddings_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            # A result count of 0 is not meaningful, so a falsy k still falls back.
            k=form_data.k if form_data.k else app.state.TOP_K,
            r=(
                form_data.r
                if form_data.r is not None
                else app.state.RELEVANCE_THRESHOLD
            ),
            embeddings_function=embeddings_function,
            reranking_function=app.state.sentence_transformer_rf,
            hybrid_search=(
                form_data.hybrid
                if form_data.hybrid is not None
                else app.state.ENABLE_RAG_HYBRID_SEARCH
            ),
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
397
398


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
399
400
401
# Request body for POST /query/collection.
class QueryCollectionsForm(BaseModel):
    # Collections to search.
    collection_names: List[str]
    # Query text.
    query: str
    # Optional per-request overrides of the app-wide query settings.
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
405
406


407
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections with the currently configured embedding setup.

    Per-request ``k``, ``r`` and ``hybrid`` override the app-wide settings.
    ``r`` and ``hybrid`` are compared against ``None`` so that an explicit
    ``0.0`` or ``False`` is honoured instead of silently falling back to the
    global value.

    Args:
        form_data: Collection names, query text and optional overrides.
        user: Authenticated user resolved by ``get_current_user``.

    Returns:
        Whatever ``query_embeddings_collection`` returns.

    Raises:
        HTTPException: 400 when the query fails for any reason.
    """
    try:
        embeddings_function = get_embeddings_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        return query_embeddings_collection(
            collection_names=form_data.collection_names,
            query=form_data.query,
            # A result count of 0 is not meaningful, so a falsy k still falls back.
            k=form_data.k if form_data.k else app.state.TOP_K,
            r=(
                form_data.r
                if form_data.r is not None
                else app.state.RELEVANCE_THRESHOLD
            ),
            embeddings_function=embeddings_function,
            reranking_function=app.state.sentence_transformer_rf,
            hybrid_search=(
                form_data.hybrid
                if form_data.hybrid is not None
                else app.state.ENABLE_RAG_HYBRID_SEARCH
            ),
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
440
441


442
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Fetch a web page and store its content in a vector collection.

    When no collection name is supplied (empty string or ``None``), one is
    derived from the SHA-256 of the URL, truncated to 63 characters. An
    existing collection with the same name is overwritten.

    Args:
        form_data: URL to load and optional target collection name.
        user: Authenticated user resolved by ``get_current_user``.

    Returns:
        A dict with ``status``, the ``collection_name`` used and the URL
        under ``filename``.

    Raises:
        HTTPException: 400 when loading or storing fails.
    """
    # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    # NOTE(review): the URL is client-supplied and fetched server-side as-is;
    # consider restricting reachable hosts if this service can see internal
    # networks.
    try:
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        # The field is Optional, so treat None the same as an empty name.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


467
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: Documents as returned by a loader's ``load()``.
        collection_name: Target collection name.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``. (A plain bool, as
        annotated: wrapping it in a tuple would make every ``if result:``
        check at the call sites succeed even when storing failed.)

    Raises:
        ValueError: When splitting produces no chunks at all.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    return store_docs_in_vector_db(docs, collection_name, overwrite)
482
483
484


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk a raw text with the configured splitter settings and store it.

    Args:
        text: Text to split.
        metadata: Metadata attached to the documents created from ``text``.
        collection_name: Target collection name.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The result of ``store_docs_in_vector_db``.
    """
    splitter_options = {
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
        "add_start_index": True,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
496
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a newly created collection.

    Args:
        docs: Chunks exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection of that name is deleted
            before the new one is created.

    Returns:
        True on success, and also when the failure is an exception whose class
        is named ``UniqueConstraintError``; False for any other failure.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = []
    metadatas = []
    for doc in docs:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if existing.name != collection_name:
                    continue
                log.info(f"deleting existing collection {collection_name}")
                CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = get_embeddings_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        # Newlines are replaced by spaces for embedding only; the stored
        # documents keep their original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid1()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # Matched by class name; no exception class is imported for this check.
        return type(e).__name__ == "UniqueConstraintError"


540
541
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a document loader for a file from its extension and content type.

    The extension is whatever follows the last "." of ``filename``, lowercased
    (the whole name when it contains no dot). Checks run in a fixed order and
    the first match wins.

    Args:
        filename: File name used to derive the extension.
        file_content_type: MIME type of the file; may be empty or None.
        file_path: Path handed to the selected loader.

    Returns:
        A ``(loader, known_type)`` tuple. ``known_type`` is False only when
        nothing matched and the plain-text loader is used as a last resort.
    """
    extension = filename.rsplit(".", 1)[-1].lower()

    # Extensions that are read as plain text.
    source_extensions = (
        "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
        "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", "r",
        "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", "nginxconf",
        "conf", "m", "mm", "plsql", "perl", "rb", "rs", "db2", "scala",
        "bash", "swift", "vue", "svelte",
    )
    docx_mime = (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    )
    excel_mimes = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )

    if extension == "pdf":
        return (
            PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES),
            True,
        )
    if extension == "csv":
        return CSVLoader(file_path), True
    if extension == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if extension == "xml":
        return UnstructuredXMLLoader(file_path), True
    if extension in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if extension == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    if file_content_type == docx_mime or extension in ("doc", "docx"):
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_mimes or extension in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True

    is_text_like = extension in source_extensions or (
        file_content_type and "text/" in file_content_type
    )
    # The same loader serves both cases; only the known_type flag differs.
    fallback_loader = TextLoader(file_path, autodetect_encoding=True)
    if is_text_like:
        return fallback_loader, True
    return fallback_loader, False


625
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under UPLOAD_DIR and index it in a collection.

    When ``collection_name`` is not supplied, it is derived from the SHA-256
    of the saved file, truncated to 63 characters.

    Args:
        collection_name: Optional target collection (form field).
        file: The uploaded file.
        user: Authenticated user resolved by ``get_current_user``.

    Returns:
        A dict with ``status``, ``collection_name``, ``filename`` and
        ``known_type`` (whether a specific loader matched the file).

    Raises:
        HTTPException: 400 for any failure, with a dedicated message when the
            error text reports that pandoc is missing.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # Keep only the final path component of the client-supplied name.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            # The context manager closes the handle even if hashing raises.
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)

            if result:
                return {
                    "status": True,
                    "collection_name": collection_name,
                    "filename": filename,
                    "known_type": known_type,
                }
        except Exception as e:
            # NOTE(review): this HTTPException is itself caught by the outer
            # handler below, so clients see a 400 rather than this 500.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=e,
            )

        # A falsy result means nothing was stored; report it rather than
        # answering 200 with an empty body. Like every other failure in this
        # endpoint, it is surfaced by the outer handler.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
680
681


682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
# Request body for POST /text.
class TextRAGForm(BaseModel):
    # Stored as the "name" metadata entry of the created documents.
    name: str
    # Text to split and store.
    content: str
    # Optional target collection; store_text derives one from the content when omitted.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Store a piece of raw text in a vector collection.

    When no collection name is supplied, one is derived from the SHA-256 of
    the content and truncated to 63 characters, the same way the /web and
    /doc endpoints derive theirs.

    Args:
        form_data: Text, its display name and an optional collection name.
        user: Authenticated user; ``user.id`` is recorded as ``created_by``.

    Returns:
        A dict with ``status`` and the ``collection_name`` used.

    Raises:
        HTTPException: 500 when the text could not be stored.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


713
714
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file found under DOCS_DIR (admin only).

    Each file is stored in a collection named after the first 63 characters
    of its SHA-256. When storing succeeds and no document with the sanitized
    file name exists yet, a document record is inserted with the folder tags
    from ``extract_folders_after_data_docs`` as its content. Failures are
    logged per file and never abort the scan.

    Returns:
        Always True.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if path.is_file() and not path.name.startswith("."):
                tags = extract_folders_after_data_docs(path)
                filename = path.name
                file_content_type = mimetypes.guess_type(path)

                # The context manager closes the handle even if hashing raises.
                with open(path, "rb") as f:
                    collection_name = calculate_sha256(f)[:63]

                loader, known_type = get_loader(
                    filename, file_content_type[0], str(path)
                )
                data = loader.load()

                try:
                    result = store_data_in_vector_db(data, collection_name)

                    if result:
                        sanitized_filename = sanitize_filename(filename)
                        doc = Documents.get_doc_by_name(sanitized_filename)

                        if doc is None:
                            doc = Documents.insert_new_doc(
                                user.id,
                                DocumentForm(
                                    **{
                                        "name": sanitized_filename,
                                        "title": filename,
                                        "collection_name": collection_name,
                                        "filename": filename,
                                        "content": (
                                            json.dumps(
                                                {
                                                    "tags": [
                                                        {"name": name}
                                                        for name in tags
                                                    ]
                                                }
                                            )
                                            if len(tags)
                                            else "{}"
                                        ),
                                    }
                                ),
                            )
                except Exception as e:
                    # Best effort: a file that cannot be stored is logged and skipped.
                    log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
774
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the Chroma client (admin only); the response body is null."""
    CHROMA_CLIENT.reset()
    return None
Timothy J. Baek's avatar
Timothy J. Baek committed
777
778
779


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything under UPLOAD_DIR, then reset the Chroma client (admin only).

    Deletion failures are logged per entry and a failing Chroma reset is
    logged as well; neither stops the handler.

    Returns:
        Always True.
    """
    folder = f"{UPLOAD_DIR}"

    for entry_name in os.listdir(folder):
        file_path = os.path.join(folder, entry_name)
        try:
            is_file_or_link = os.path.isfile(file_path) or os.path.islink(file_path)
            if is_file_or_link:
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s" % (file_path, e))

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True