main.py 23.3 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
)
32
33
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pydantic import BaseModel
from typing import Optional
36
import mimetypes
37
import uuid
38
39
import json

40
import sentence_transformers
41

42
43
44
45
46
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
47

48
from apps.rag.utils import (
49
    get_model_path,
Timothy J. Baek's avatar
Timothy J. Baek committed
50
51
52
53
54
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
55
)
Timothy J. Baek's avatar
Timothy J. Baek committed
56

57
58
59
60
61
62
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
63
from utils.utils import get_current_user, get_admin_user
64

65
from config import (
66
    SRC_LOG_LEVELS,
67
68
    UPLOAD_DIR,
    DOCS_DIR,
69
70
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
71
    RAG_EMBEDDING_ENGINE,
72
    RAG_EMBEDDING_MODEL,
73
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
74
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
75
    ENABLE_RAG_HYBRID_SEARCH,
Steven Kreitzer's avatar
Steven Kreitzer committed
76
    RAG_RERANKING_MODEL,
77
    RAG_RERANKING_MODEL_AUTO_UPDATE,
Steven Kreitzer's avatar
Steven Kreitzer committed
78
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
79
80
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
81
    DEVICE_TYPE,
82
83
84
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
85
    RAG_TEMPLATE,
86
)
87

88
89
from constants import ERROR_MESSAGES

90
91
92
# Module logger; its level comes from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Runtime-adjustable settings live on ``app.state`` and are seeded from config.
# The update endpoints in this module overwrite them in place.
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD

app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH

app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# PDF image extraction starts disabled; POST /config/update can change it.
app.state.PDF_EXTRACT_IMAGES = False

Steven Kreitzer's avatar
Steven Kreitzer committed
113

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Refresh ``app.state.sentence_transformer_ef`` for the given model name.

    A ``SentenceTransformer`` is loaded only when ``embedding_model`` is
    non-empty and ``app.state.RAG_EMBEDDING_ENGINE`` is the empty string;
    in every other case the attribute is set to ``None``.

    Args:
        embedding_model: Model name, resolved through ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    wants_sentence_transformer = (
        bool(embedding_model) and app.state.RAG_EMBEDDING_ENGINE == ""
    )
    if not wants_sentence_transformer:
        app.state.sentence_transformer_ef = None
        return

    model_path = get_model_path(embedding_model, update_model)
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Refresh ``app.state.sentence_transformer_rf`` for the given model name.

    A ``CrossEncoder`` is loaded when ``reranking_model`` is non-empty;
    otherwise the attribute is set to ``None``.

    Args:
        reranking_model: Model name, resolved through ``get_model_path``.
        update_model: Forwarded unchanged to ``get_model_path``.
    """
    if not reranking_model:
        app.state.sentence_transformer_rf = None
        return

    model_path = get_model_path(reranking_model, update_model)
    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        model_path,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    )


# Load the configured embedding and reranking models once at import time.
# The *_AUTO_UPDATE flags are passed as ``update_model``.
update_embedding_model(
    app.state.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)

update_reranking_model(
    app.state.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
151

Timothy J. Baek's avatar
Timothy J. Baek committed
152
153
154
155
156
157
158
159
160

# Build the embedding callable used by the query handlers from the current
# engine, model, loaded sentence transformer and OpenAI settings.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.RAG_EMBEDDING_ENGINE,
    app.state.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.OPENAI_API_KEY,
    app.state.OPENAI_API_BASE_URL,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
161
162
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS policy -- confirm this is intended for every deployment.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
173
class CollectionNameForm(BaseModel):
    """Base request body carrying an optional target collection name."""

    # Defaults to the literal "test" when the client omits the field.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
177
178
179
class StoreWebForm(CollectionNameForm):
    """Request body for POST /web: a URL plus the inherited collection name."""

    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
180

Timothy J. Baek's avatar
Timothy J. Baek committed
181
182
@app.get("/")
async def get_status():
    """Return the service status together with the active RAG settings.

    No authentication dependency is attached to this route.
    """
    return {
        "status": True,
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
        "template": app.state.RAG_TEMPLATE,
        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "reranking_model": app.state.RAG_RERANKING_MODEL,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
194
195
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the current embedding engine, model and OpenAI settings.

    The response includes the stored OpenAI API key verbatim. The route is
    guarded by the ``get_admin_user`` dependency.
    """
    return {
        "status": True,
        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "openai_config": {
            "url": app.state.OPENAI_API_BASE_URL,
            "key": app.state.OPENAI_API_KEY,
        },
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
207
208
209
210
211
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the name of the currently configured reranking model."""
    payload = {
        "status": True,
        "reranking_model": app.state.RAG_RERANKING_MODEL,
    }
    return payload


212
213
214
215
216
class OpenAIConfigForm(BaseModel):
    """Base URL and API key pair supplied when updating the embedding config."""

    url: str
    key: str


217
class EmbeddingModelUpdateForm(BaseModel):
    """Request body for POST /embedding/update."""

    # Only applied when the engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
223
224
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model and rebuild the embedding function.

    Args:
        form_data: New engine, model and (optionally) OpenAI settings.
        user: Injected by ``get_admin_user``.

    Returns:
        The resulting engine, model and OpenAI configuration.

    Raises:
        HTTPException: 500 when any step of the update fails.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        # NOTE(review): state is overwritten before the model is loaded, so a
        # failed load leaves the new names paired with the previous model
        # objects -- confirm whether a rollback is wanted.
        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
            if form_data.openai_config is not None:
                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.OPENAI_API_KEY = form_data.openai_config.key

        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        return {
            "status": True,
            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.OPENAI_API_BASE_URL,
                "key": app.state.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
264
265


Steven Kreitzer's avatar
Steven Kreitzer committed
266
267
class RerankingModelUpdateForm(BaseModel):
    """Request body for POST /reranking/update."""

    reranking_model: str
268

Steven Kreitzer's avatar
Steven Kreitzer committed
269
270
271
272
273
274
275
276
277
278

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it.

    Args:
        form_data: Carries the new reranking model name.
        user: Injected by ``get_admin_user``.

    Returns:
        The resulting reranking model name.

    Raises:
        HTTPException: 500 when the model cannot be updated.
    """
    log.info(
        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.RAG_RERANKING_MODEL = form_data.reranking_model

        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e


Timothy J. Baek's avatar
Timothy J. Baek committed
294
295
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF image-extraction flag and the chunking parameters."""
    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.CHUNK_SIZE,
            "chunk_overlap": app.state.CHUNK_OVERLAP,
        },
    }


class ChunkParamUpdateForm(BaseModel):
    """Chunking parameters handed to the text splitter."""

    chunk_size: int
    chunk_overlap: int


Timothy J. Baek's avatar
Timothy J. Baek committed
311
312
313
314
315
316
317
318
319
320
class ConfigUpdateForm(BaseModel):
    """Request body for POST /config/update."""

    pdf_extract_images: bool
    chunk: ChunkParamUpdateForm


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Overwrite the PDF image-extraction flag and the chunking parameters.

    Returns:
        The values now stored on ``app.state``.
    """
    app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
    app.state.CHUNK_SIZE = form_data.chunk.chunk_size
    app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.CHUNK_SIZE,
            "chunk_overlap": app.state.CHUNK_OVERLAP,
        },
    }
330
331


Timothy J. Baek's avatar
Timothy J. Baek committed
332
333
334
335
336
337
338
339
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template currently stored on ``app.state``."""
    payload = {"status": True}
    payload["template"] = app.state.RAG_TEMPLATE
    return payload


340
341
342
343
344
345
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the template, top-k, relevance threshold and hybrid-search flag."""
    return {
        "status": True,
        "template": app.state.RAG_TEMPLATE,
        "k": app.state.TOP_K,
        "r": app.state.RELEVANCE_THRESHOLD,
        "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
349
350


351
352
class QuerySettingsForm(BaseModel):
    """Request body for POST /query/settings/update; every field is optional."""

    k: Optional[int] = None
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None
356
357
358
359
360
361
362
363


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Overwrite the query settings, falling back to defaults for falsy fields.

    A missing or falsy field resets its setting: ``template`` to the configured
    ``RAG_TEMPLATE``, ``k`` to 4, ``r`` to 0.0 and ``hybrid`` to ``False``.

    Returns:
        The values now stored on ``app.state``.
    """
    app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
    app.state.TOP_K = form_data.k if form_data.k else 4
    app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
    app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False

    return {
        "status": True,
        "template": app.state.RAG_TEMPLATE,
        "k": app.state.TOP_K,
        "r": app.state.RELEVANCE_THRESHOLD,
        "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
    }
373
374


375
class QueryDocForm(BaseModel):
    """Request body for POST /query/doc."""

    collection_name: str
    query: str
    # Per-request overrides; the handler falls back to app.state when unset.
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
381
382


383
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single collection, using hybrid search when it is enabled.

    Args:
        form_data: Collection name, query text and optional ``k``/``r``
            overrides.
        user: Injected by ``get_current_user``.

    Returns:
        Whatever ``query_doc`` / ``query_doc_with_hybrid_search`` returns.

    Raises:
        HTTPException: 400 when the underlying query raises.
    """
    # NOTE(review): ``form_data.hybrid`` is accepted but never consulted; only
    # the global app.state flag selects the search mode -- confirm intended.
    top_k = form_data.k if form_data.k else app.state.TOP_K
    try:
        if app.state.ENABLE_RAG_HYBRID_SEARCH:
            # Test against None so an explicit threshold of 0.0 is honoured
            # instead of being replaced by the global default.
            relevance = (
                form_data.r
                if form_data.r is not None
                else app.state.RELEVANCE_THRESHOLD
            )
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embeddings_function=app.state.EMBEDDING_FUNCTION,
                reranking_function=app.state.sentence_transformer_rf,
                k=top_k,
                r=relevance,
            )
        else:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embeddings_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
411
412


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
413
414
415
class QueryCollectionsForm(BaseModel):
    """Request body for POST /query/collection."""

    collection_names: List[str]
    query: str
    # Per-request overrides; the handler falls back to app.state when unset.
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
419
420


421
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections, using hybrid search when it is enabled.

    Args:
        form_data: Collection names, query text and optional ``k``/``r``
            overrides.
        user: Injected by ``get_current_user``.

    Returns:
        Whatever ``query_collection`` /
        ``query_collection_with_hybrid_search`` returns.

    Raises:
        HTTPException: 400 when the underlying query raises.
    """
    # NOTE(review): ``form_data.hybrid`` is accepted but never consulted; only
    # the global app.state flag selects the search mode -- confirm intended.
    top_k = form_data.k if form_data.k else app.state.TOP_K
    try:
        if app.state.ENABLE_RAG_HYBRID_SEARCH:
            # Test against None so an explicit threshold of 0.0 is honoured
            # instead of being replaced by the global default.
            relevance = (
                form_data.r
                if form_data.r is not None
                else app.state.RELEVANCE_THRESHOLD
            )
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embeddings_function=app.state.EMBEDDING_FUNCTION,
                reranking_function=app.state.sentence_transformer_rf,
                k=top_k,
                r=relevance,
            )
        else:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embeddings_function=app.state.EMBEDDING_FUNCTION,
                k=top_k,
            )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
450
451


452
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Fetch a web page and store its content in the vector database.

    An existing collection with the same name is overwritten.

    Args:
        form_data: URL to load plus an optional collection name. The form
            defaults the name to ``"test"``, so the URL-hash fallback only
            applies when the client sends an empty or null name.
        user: Injected by ``get_current_user``.

    Returns:
        A dict with the collection name used and the URL as ``filename``.

    Raises:
        HTTPException: 400 when loading or storing fails.
    """
    # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        # ``collection_name`` is Optional, so treat an explicit null like "".
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e


477
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: Documents as returned by a loader's ``load()``.
        collection_name: Target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.

    Raises:
        ValueError: When splitting yields no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if len(docs) > 0:
        log.info(f"store_data_in_vector_db {docs}")
        # Return the bare bool: a ``(result, None)`` tuple is always truthy and
        # would hide storage failures from callers that test ``if result:``.
        return store_docs_in_vector_db(docs, collection_name, overwrite)
    else:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
492
493
494


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw text into chunks and store them in a collection.

    Args:
        text: Text to split.
        metadata: Metadata dict attached to the text.
        collection_name: Target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    docs = text_splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(docs, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
506
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a new Chroma collection.

    Args:
        docs: Chunks exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection of the same name is
            deleted first.

    Returns:
        ``True`` on success, or when the failure is an exception whose class is
        named ``UniqueConstraintError`` (presumably the collection already
        exists -- TODO confirm). ``False`` on any other error.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        if overwrite:
            for collection in CHROMA_CLIENT.list_collections():
                if collection_name == collection.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = get_embedding_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        # Newlines are flattened only for embedding; the stored documents keep
        # their original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        for batch in create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid1()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        ):
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # Matched by class name, so the exception type need not be imported.
        if type(e).__name__ == "UniqueConstraintError":
            return True

        return False


550
551
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a document loader from the file extension and content type.

    Args:
        filename: Name whose last dot-separated segment, lower-cased, is used
            as the extension.
        file_content_type: MIME type of the file; may be ``None``.
        file_path: Path handed to the selected loader.

    Returns:
        A ``(loader, known_type)`` tuple. ``known_type`` is ``False`` only
        when no rule matched and the plain-text loader is used as a fallback.
    """
    file_ext = filename.split(".")[-1].lower()
    known_type = True

    known_source_ext = [
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    ]

    if file_ext == "pdf":
        loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext in ["htm", "html"]:
        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif file_content_type == "application/epub+zip":
        loader = UnstructuredEPubLoader(file_path)
    elif (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or file_ext in ["doc", "docx"]
    ):
        loader = Docx2txtLoader(file_path)
    elif file_content_type in [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ] or file_ext in ["xls", "xlsx"]:
        loader = UnstructuredExcelLoader(file_path)
    elif file_ext in known_source_ext or (
        file_content_type and "text/" in file_content_type
    ):
        loader = TextLoader(file_path, autodetect_encoding=True)
    else:
        # Unknown type: still try to read it as text, but tell the caller.
        loader = TextLoader(file_path, autodetect_encoding=True)
        known_type = False

    return loader, known_type


635
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file and store its content in the vector database.

    Args:
        collection_name: Target collection; when omitted, the first 63
            characters of the file's SHA-256 are used.
        file: The uploaded file.
        user: Injected by ``get_current_user``.

    Returns:
        A dict with the collection name, stored filename and whether the file
        type was recognised.

    Raises:
        HTTPException: 400 when saving, loading or storing raises (with a
            dedicated message when pandoc is missing); 500 when the store
            reports failure without raising.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # Keep only the final path component of the client-supplied name.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        result = store_data_in_vector_db(data, collection_name)
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            ) from e
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            ) from e

    # Raised outside the ``try`` so the 500 is not rewritten into a 400 by the
    # generic handler above.
    if not result:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "filename": filename,
        "known_type": known_type,
    }
690
691


692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
class TextRAGForm(BaseModel):
    """Request body for POST /text."""

    name: str
    content: str
    # When omitted, the handler derives the name from a hash of ``content``.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Store a raw text snippet in the vector database.

    Args:
        form_data: Name, content and optional collection name.
        user: Injected by ``get_current_user``; its ``id`` is recorded in the
            chunk metadata as ``created_by``.

    Returns:
        A dict with the collection name used.

    Raises:
        HTTPException: 500 when the store reports failure.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # Truncated to 63 characters, matching the other hash-derived
        # collection names in this module.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


723
724
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under ``DOCS_DIR`` into the vector database.

    Each file is stored in a collection named after the first 63 characters
    of its SHA-256. A document record is inserted when none exists yet for the
    sanitized filename. Failures are logged per file and do not stop the scan.

    Returns:
        Always ``True``.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, _ = get_loader(filename, file_content_type[0], str(path))
            data = loader.load()

            if not store_data_in_vector_db(data, collection_name):
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is not None:
                continue

            content = (
                json.dumps({"tags": [{"name": name} for name in tags]})
                if len(tags)
                else "{}"
            )
            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    **{
                        "name": sanitized_filename,
                        "title": filename,
                        "collection_name": collection_name,
                        "filename": filename,
                        "content": content,
                    }
                ),
            )
        except Exception as e:
            # Best effort: log the failure and move on to the next file.
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
784
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the Chroma client; any exception it raises propagates."""
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
787
788
789


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything in ``UPLOAD_DIR`` and reset the Chroma client.

    Both steps are best effort: failures are logged and do not stop the reset.

    Returns:
        Always ``True``.
    """
    folder = f"{UPLOAD_DIR}"
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s", file_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True