main.py 22.6 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
)
32
33
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pydantic import BaseModel
from typing import Optional
36
import mimetypes
37
import uuid
38
39
import json

40
import sentence_transformers
41

42
43
44
45
46
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
47

48
from apps.rag.utils import (
49
    get_model_path,
50
    query_embeddings_doc,
Steven Kreitzer's avatar
Steven Kreitzer committed
51
    query_embeddings_function,
52
53
    query_embeddings_collection,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
54

55
56
57
58
59
60
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
61
from utils.utils import get_current_user, get_admin_user
62

63
from config import (
64
    SRC_LOG_LEVELS,
65
66
    UPLOAD_DIR,
    DOCS_DIR,
67
68
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
69
    RAG_EMBEDDING_ENGINE,
70
    RAG_EMBEDDING_MODEL,
71
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
72
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Steven Kreitzer's avatar
Steven Kreitzer committed
73
    RAG_HYBRID,
Steven Kreitzer's avatar
Steven Kreitzer committed
74
    RAG_RERANKING_MODEL,
75
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
76
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
77
78
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
79
    DEVICE_TYPE,
80
81
82
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
83
    RAG_TEMPLATE,
84
)
85

86
87
from constants import ERROR_MESSAGES

88
89
90
# Module logger; verbosity comes from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Runtime settings are kept on app.state, seeded here from config values.
# Several of them are overwritten in place by the update endpoints below.

# Retrieval defaults.
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.HYBRID = RAG_HYBRID

# Text-splitting parameters.
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking model selection and the prompt template.
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

# Credentials handed to query_embeddings_function.
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# Forwarded to PyPDFLoader(extract_images=...); off until changed via /config/update.
app.state.PDF_EXTRACT_IMAGES = False

Steven Kreitzer's avatar
Steven Kreitzer committed
110

111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Set ``app.state.sentence_transformer_ef`` for the given model name.

    A ``SentenceTransformer`` is instantiated only when ``embedding_model`` is
    non-empty and ``app.state.RAG_EMBEDDING_ENGINE`` is the empty string;
    in every other case the attribute is reset to ``None``.

    Args:
        embedding_model: Model identifier passed to ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    if not embedding_model or app.state.RAG_EMBEDDING_ENGINE != "":
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Set ``app.state.sentence_transformer_rf`` for the given model name.

    A ``CrossEncoder`` is instantiated when ``reranking_model`` is non-empty;
    otherwise the attribute is reset to ``None``.

    Args:
        reranking_model: Model identifier passed to ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the configured models once at import time so that
# app.state.sentence_transformer_ef / _rf exist before any request is served.
update_embedding_model(
    embedding_model=app.state.RAG_EMBEDDING_MODEL,
    update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    reranking_model=app.state.RAG_RERANKING_MODEL,
    update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
148

Timothy J. Baek's avatar
Timothy J. Baek committed
149
150
# CORS: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive - confirm this is intended for the deployment.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
161
class CollectionNameForm(BaseModel):
    """Base request body carrying an optional target collection name."""

    # Falls back to the literal "test" when the client omits the field.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
165
166
167
class StoreWebForm(CollectionNameForm):
    """Request body for ``POST /web``: a URL plus the inherited collection name."""

    # Passed to WebBaseLoader by store_web.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
168

Timothy J. Baek's avatar
Timothy J. Baek committed
169
170
@app.get("/")
async def get_status():
    """Return the chunking, template and model settings currently on ``app.state``."""
    state = app.state
    payload = {
        "status": True,
        "chunk_size": state.CHUNK_SIZE,
        "chunk_overlap": state.CHUNK_OVERLAP,
        "template": state.RAG_TEMPLATE,
        "embedding_engine": state.RAG_EMBEDDING_ENGINE,
        "embedding_model": state.RAG_EMBEDDING_MODEL,
        "reranking_model": state.RAG_RERANKING_MODEL,
    }
    return payload


Timothy J. Baek's avatar
Timothy J. Baek committed
182
183
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine, model and OpenAI connection settings.

    Guarded by the ``get_admin_user`` dependency. The response includes the
    stored API key verbatim.
    """
    state = app.state
    openai_config = {
        "url": state.OPENAI_API_BASE_URL,
        "key": state.OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": state.RAG_EMBEDDING_ENGINE,
        "embedding_model": state.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
195
196
197
198
199
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the reranking model name currently stored on ``app.state``."""
    current_model = app.state.RAG_RERANKING_MODEL
    return {
        "status": True,
        "reranking_model": current_model,
    }


200
201
202
203
204
class OpenAIConfigForm(BaseModel):
    """Connection settings copied onto ``app.state`` by ``/embedding/update``."""

    # Stored as app.state.OPENAI_API_BASE_URL.
    url: str
    # Stored as app.state.OPENAI_API_KEY.
    key: str


205
class EmbeddingModelUpdateForm(BaseModel):
    """Request body for ``POST /embedding/update``."""

    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    # An empty string selects the local sentence-transformers path
    # (see update_embedding_model).
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
211
212
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and reload the local model if needed.

    The new engine and model names are written to ``app.state`` first because
    ``update_embedding_model`` reads the engine from there. If loading fails,
    the previous engine, model and OpenAI settings are restored so that
    ``app.state`` keeps describing the model that is actually loaded.

    Raises:
        HTTPException: 500 with ``ERROR_MESSAGES.DEFAULT(e)`` when the update fails.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )

    # Snapshot taken before any mutation so a failed load can be undone.
    previous_engine = app.state.RAG_EMBEDDING_ENGINE
    previous_model = app.state.RAG_EMBEDDING_MODEL
    previous_url = app.state.OPENAI_API_BASE_URL
    previous_key = app.state.OPENAI_API_KEY

    try:
        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
            if form_data.openai_config is not None:
                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.OPENAI_API_KEY = form_data.openai_config.key

        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)

        return {
            "status": True,
            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.OPENAI_API_BASE_URL,
                "key": app.state.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")

        # sentence_transformer_ef was not replaced, so put the settings back.
        app.state.RAG_EMBEDDING_ENGINE = previous_engine
        app.state.RAG_EMBEDDING_MODEL = previous_model
        app.state.OPENAI_API_BASE_URL = previous_url
        app.state.OPENAI_API_KEY = previous_key

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
244
245


Steven Kreitzer's avatar
Steven Kreitzer committed
246
247
class RerankingModelUpdateForm(BaseModel):
    """Request body for ``POST /reranking/update``."""

    # An empty string makes update_reranking_model unload the reranker.
    reranking_model: str
248

Steven Kreitzer's avatar
Steven Kreitzer committed
249
250
251
252
253
254
255
256
257
258

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it.

    If loading the new model fails, the previously configured model name is
    restored so that ``app.state.RAG_RERANKING_MODEL`` keeps matching the
    reranker that is actually loaded.

    Raises:
        HTTPException: 500 with ``ERROR_MESSAGES.DEFAULT(e)`` when the update fails.
    """
    log.info(
        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )

    previous_model = app.state.RAG_RERANKING_MODEL
    try:
        app.state.RAG_RERANKING_MODEL = form_data.reranking_model

        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")

        # sentence_transformer_rf was not replaced, so put the name back.
        app.state.RAG_RERANKING_MODEL = previous_model

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e


Timothy J. Baek's avatar
Timothy J. Baek committed
274
275
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF image-extraction flag and the chunking parameters."""
    chunk_settings = {
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
    }
    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
    }


class ChunkParamUpdateForm(BaseModel):
    """Chunking parameters nested inside ``ConfigUpdateForm``."""

    # Stored as app.state.CHUNK_SIZE and later passed to the text splitter.
    chunk_size: int
    # Stored as app.state.CHUNK_OVERLAP and later passed to the text splitter.
    chunk_overlap: int


Timothy J. Baek's avatar
Timothy J. Baek committed
291
292
293
294
295
296
297
298
299
300
class ConfigUpdateForm(BaseModel):
    """Request body for ``POST /config/update``."""

    # Stored as app.state.PDF_EXTRACT_IMAGES.
    pdf_extract_images: bool
    # New chunk size / overlap pair.
    chunk: ChunkParamUpdateForm


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Overwrite the PDF image-extraction flag and chunking parameters.

    The values are stored on ``app.state`` as received and echoed back in the
    same shape that ``GET /config`` returns.
    """
    state = app.state
    chunk = form_data.chunk

    state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
    state.CHUNK_SIZE = chunk.chunk_size
    state.CHUNK_OVERLAP = chunk.chunk_overlap

    chunk_settings = {
        "chunk_size": state.CHUNK_SIZE,
        "chunk_overlap": state.CHUNK_OVERLAP,
    }
    return {
        "status": True,
        "pdf_extract_images": state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
    }
310
311


Timothy J. Baek's avatar
Timothy J. Baek committed
312
313
314
315
316
317
318
319
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template currently stored on ``app.state``."""
    template = app.state.RAG_TEMPLATE
    return {"status": True, "template": template}


320
321
322
323
324
325
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the template, top-k, relevance threshold and hybrid flag."""
    state = app.state
    settings = {
        "status": True,
        "template": state.RAG_TEMPLATE,
        "k": state.TOP_K,
        "r": state.RELEVANCE_THRESHOLD,
        "hybrid": state.HYBRID,
    }
    return settings
Timothy J. Baek's avatar
Timothy J. Baek committed
329
330


331
332
class QuerySettingsForm(BaseModel):
    """Request body for ``POST /query/settings/update``.

    Every field is optional; ``update_query_settings`` substitutes a fallback
    for any field that is missing or falsy.
    """

    # Stored as app.state.TOP_K.
    k: Optional[int] = None
    # Stored as app.state.RELEVANCE_THRESHOLD.
    r: Optional[float] = None
    # Stored as app.state.RAG_TEMPLATE.
    template: Optional[str] = None
    # Stored as app.state.HYBRID.
    hybrid: Optional[bool] = None
336
337
338
339
340
341
342
343


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Overwrite the query settings on ``app.state`` and echo them back.

    Missing or falsy fields fall back to ``RAG_TEMPLATE``, ``4``, ``0.0`` and
    ``False`` respectively.
    """
    state = app.state

    state.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    state.TOP_K = form_data.k or 4
    state.RELEVANCE_THRESHOLD = form_data.r or 0.0
    state.HYBRID = form_data.hybrid or False

    settings = {
        "status": True,
        "template": state.RAG_TEMPLATE,
        "k": state.TOP_K,
        "r": state.RELEVANCE_THRESHOLD,
        "hybrid": state.HYBRID,
    }
    return settings
353
354


355
class QueryDocForm(BaseModel):
    """Request body for ``POST /query/doc`` (single-collection query)."""

    collection_name: str
    query: str
    # Optional per-request overrides; query_doc_handler falls back to the
    # corresponding app.state values when they are not supplied.
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
361
362


363
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query one collection and return the result of ``query_embeddings_doc``.

    Per-request ``k``, ``r`` and ``hybrid`` override the ``app.state`` defaults.
    ``r`` and ``hybrid`` are compared against ``None`` so that an explicit
    ``0.0`` or ``False`` from the client is honoured instead of being replaced
    by the server default. ``k`` keeps the truthiness test, so ``0`` still
    falls back to the default.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.DEFAULT(e)`` on any failure.
    """
    try:
        embeddings_function = query_embeddings_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        return query_embeddings_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            k=form_data.k if form_data.k else app.state.TOP_K,
            r=(
                form_data.r
                if form_data.r is not None
                else app.state.RELEVANCE_THRESHOLD
            ),
            embeddings_function=embeddings_function,
            reranking_function=app.state.sentence_transformer_rf,
            hybrid=(
                form_data.hybrid
                if form_data.hybrid is not None
                else app.state.HYBRID
            ),
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
392
393


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
394
395
396
class QueryCollectionsForm(BaseModel):
    """Request body for ``POST /query/collection`` (multi-collection query)."""

    collection_names: List[str]
    query: str
    # Optional per-request overrides; query_collection_handler falls back to
    # the corresponding app.state values when they are not supplied.
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
400
401


402
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections via ``query_embeddings_collection``.

    Per-request ``k``, ``r`` and ``hybrid`` override the ``app.state`` defaults.
    ``r`` and ``hybrid`` are compared against ``None`` so that an explicit
    ``0.0`` or ``False`` from the client is honoured instead of being replaced
    by the server default. ``k`` keeps the truthiness test, so ``0`` still
    falls back to the default.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.DEFAULT(e)`` on any failure.
    """
    try:
        embeddings_function = query_embeddings_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        return query_embeddings_collection(
            collection_names=form_data.collection_names,
            query=form_data.query,
            k=form_data.k if form_data.k else app.state.TOP_K,
            r=(
                form_data.r
                if form_data.r is not None
                else app.state.RELEVANCE_THRESHOLD
            ),
            embeddings_function=embeddings_function,
            reranking_function=app.state.sentence_transformer_rf,
            hybrid=(
                form_data.hybrid
                if form_data.hybrid is not None
                else app.state.HYBRID
            ),
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
431
432


433
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Fetch a web page and store its chunks, overwriting any same-named collection.

    When the collection name is empty or ``null`` it is derived from the first
    63 characters of the SHA-256 of the URL. A falsy result from
    ``store_data_in_vector_db`` is reported as an error rather than as success.

    Raises:
        HTTPException: 500 when storing reports failure; 400 with
            ``ERROR_MESSAGES.DEFAULT(e)`` for any other error.
    """
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    # NOTE(review): the URL comes straight from the request and is fetched
    # server-side - confirm whether internal addresses need to be blocked.
    try:
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        # Covers both "" and an explicit null sent by the client.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        if not store_data_in_vector_db(data, collection_name, overwrite=True):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as a 400.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


458
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: Documents as returned by a loader's ``load()``.
        collection_name: Target collection name.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``. This used to be
        wrapped in a ``(result, None)`` tuple, which is always truthy and hid
        storage failures from callers that test ``if result:``.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` when splitting yields no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if len(docs) == 0:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    return store_docs_in_vector_db(docs, collection_name, overwrite)
473
474
475


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk a single text with its metadata and store the chunks.

    Returns:
        The result of ``store_docs_in_vector_db`` for the generated chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    stored = store_docs_in_vector_db(chunks, collection_name, overwrite)
    return stored


Timothy J. Baek's avatar
Timothy J. Baek committed
487
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed chunks and add them to a newly created Chroma collection.

    Args:
        docs: Chunks exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection with this name is deleted first.

    Returns:
        ``True`` on success, and also when the failure is an exception whose
        class is named ``UniqueConstraintError``; ``False`` for any other error.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = []
    metadatas = []
    for doc in docs:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)

    try:
        if overwrite:
            existing_names = [c.name for c in CHROMA_CLIENT.list_collections()]
            if collection_name in existing_names:
                log.info(f"deleting existing collection {collection_name}")
                CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = query_embeddings_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        # Newlines are flattened for embedding only; the stored documents
        # keep their original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        ids = [str(uuid.uuid1()) for _ in texts]
        batches = create_batches(
            api=CHROMA_CLIENT,
            ids=ids,
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        )
        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # Matched by class name, so the exception type need not be imported.
        return e.__class__.__name__ == "UniqueConstraintError"


531
532
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a document loader from the file extension and content type.

    Args:
        filename: Original file name; the text after its last ``.`` (lowercased)
            is treated as the extension.
        file_content_type: MIME type, possibly ``None``.
        file_path: Path handed to the selected loader.

    Returns:
        ``(loader, known_type)``. ``known_type`` is ``False`` only when no rule
        matched and the plain-text loader was used as a last resort.
    """
    file_ext = filename.rsplit(".", 1)[-1].lower()

    # Extensions that are read as plain text.
    known_source_ext = {
        "go", "py", "java", "sh", "bat", "ps1", "cmd",
        "js", "ts", "css", "cpp", "hpp", "h", "c", "cs",
        "sql", "log", "ini", "pl", "pm", "r", "dart",
        "dockerfile", "env", "php", "hs", "hsc", "lua",
        "nginxconf", "conf", "m", "mm", "plsql", "perl",
        "rb", "rs", "db2", "scala", "bash", "swift",
        "vue", "svelte",
    }
    docx_mime = (
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    )
    excel_mimes = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )

    known_type = True

    # Branch order matters: extension checks for the common formats come
    # before the content-type based ones.
    if file_ext == "pdf":
        loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext in ("htm", "html"):
        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif file_content_type == "application/epub+zip":
        loader = UnstructuredEPubLoader(file_path)
    elif file_content_type == docx_mime or file_ext in ("doc", "docx"):
        loader = Docx2txtLoader(file_path)
    elif file_content_type in excel_mimes or file_ext in ("xls", "xlsx"):
        loader = UnstructuredExcelLoader(file_path)
    else:
        # Both remaining cases use the text loader; they differ only in
        # whether the type counts as recognised.
        known_type = bool(
            file_ext in known_source_ext
            or (file_content_type and "text/" in file_content_type)
        )
        loader = TextLoader(file_path, autodetect_encoding=True)

    return loader, known_type


616
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under ``UPLOAD_DIR`` and index its contents.

    When no collection name is supplied, it is derived from the first 63
    characters of the SHA-256 of the saved file.

    Changes from the earlier version:
      * files are handled through context managers only (no manual ``close()``);
      * the inner ``HTTPException(detail=e)`` wrapper is gone - it carried a
        non-serialisable detail and was immediately re-wrapped by the outer
        handler anyway;
      * a falsy storage result now raises a 500 instead of returning ``None``.

    Raises:
        HTTPException: 500 when storing reports failure; 400 with
            ``ERROR_MESSAGES.PANDOC_NOT_INSTALLED`` or
            ``ERROR_MESSAGES.DEFAULT(e)`` for any other error.
    """
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # basename() strips any directory components sent by the client.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        if not store_data_in_vector_db(data, collection_name):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as a 400.
        raise
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
671
672


673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
class TextRAGForm(BaseModel):
    """Request body for ``POST /text``."""

    # Stored in each chunk's metadata under "name".
    name: str
    # Raw text to split and embed.
    content: str
    # When omitted, store_text derives a name from the content hash.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text snippet, tagging chunks with its name and the user id.

    When no collection name is supplied, it is derived from the SHA-256 of the
    content, truncated to 63 characters as ``store_web`` and ``store_doc`` do.
    The untruncated 64-character hex digest exceeds Chroma's collection-name
    length limit.

    Raises:
        HTTPException: 500 with ``ERROR_MESSAGES.DEFAULT()`` when storing fails.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


704
705
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under ``DOCS_DIR`` and register new documents.

    For each file the collection name is the first 63 characters of the file's
    SHA-256. After a successful store, a document record is inserted unless one
    with the same sanitized name already exists. Errors are logged per file and
    do not stop the scan.

    Returns:
        Always ``True``.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # The context manager closes the handle even if hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            result = store_data_in_vector_db(data, collection_name)
            if not result:
                continue

            sanitized_filename = sanitize_filename(filename)
            doc = Documents.get_doc_by_name(sanitized_filename)
            if doc is not None:
                continue

            # Folder-derived tags are serialised into the document content.
            if len(tags):
                content = json.dumps(
                    {"tags": [{"name": name} for name in tags]}
                )
            else:
                content = "{}"

            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    **{
                        "name": sanitized_filename,
                        "title": filename,
                        "collection_name": collection_name,
                        "filename": filename,
                        "content": content,
                    }
                ),
            )
        except Exception as e:
            # Best effort: log and move on to the next file.
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
765
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Call ``CHROMA_CLIENT.reset()``; the endpoint returns no body."""
    CHROMA_CLIENT.reset()
    return None
Timothy J. Baek's avatar
Timothy J. Baek committed
768
769
770


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Empty ``UPLOAD_DIR`` and reset the Chroma client.

    Failures to delete an entry, or to reset the client, are logged and do not
    abort the operation.

    Returns:
        Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"

    for entry in os.listdir(folder):
        file_path = os.path.join(folder, entry)
        try:
            # Files and symlinks are unlinked; real directories are removed recursively.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s", file_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True