main.py 21.6 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
)
32
33
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pydantic import BaseModel
from typing import Optional
36
import mimetypes
37
import uuid
38
39
import json

40
import sentence_transformers
41

Timothy J. Baek's avatar
Timothy J. Baek committed
42
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
43

44
45
46
47
48
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
49

50
51
from apps.rag.utils import (
    query_embeddings_doc,
Steven Kreitzer's avatar
Steven Kreitzer committed
52
    query_embeddings_function,
53
54
    query_embeddings_collection,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
55

56
57
58
59
60
61
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
62
from utils.utils import get_current_user, get_admin_user
63
from config import (
64
    SRC_LOG_LEVELS,
65
66
    UPLOAD_DIR,
    DOCS_DIR,
67
    RAG_EMBEDDING_ENGINE,
68
    RAG_EMBEDDING_MODEL,
69
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Steven Kreitzer's avatar
Steven Kreitzer committed
70
71
    RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
72
73
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
74
    DEVICE_TYPE,
75
76
77
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
78
    RAG_TEMPLATE,
79
)
80

81
82
from constants import ERROR_MESSAGES

83
84
85
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Runtime-tunable settings. They are seeded from `config` (or a literal
# default) here and can be changed later through the admin endpoints, so the
# handlers always read them from `app.state`.
app.state.TOP_K = 4

app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

# Passed to PyPDFLoader as `extract_images`.
app.state.PDF_EXTRACT_IMAGES = False

# An empty engine name selects a local sentence-transformers model.
if app.state.RAG_EMBEDDING_ENGINE == "":
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        app.state.RAG_EMBEDDING_MODEL,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )
else:
    # No local model is needed for other engines, but every query/store
    # handler reads `app.state.sentence_transformer_ef` unconditionally and
    # Starlette's State raises AttributeError for unset keys. Define it as
    # None, matching what /embedding/update does for remote engines.
    app.state.sentence_transformer_ef = None

# Cross-encoder used to rerank retrieved chunks; loaded at import time.
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
    app.state.RAG_RERANKING_MODEL,
    device=DEVICE_TYPE,
    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
)

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended for this sub-app.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
129
class CollectionNameForm(BaseModel):
    """Request body fragment naming the target vector collection."""

    # Defaults to "test" when the client omits the field.
    collection_name: Optional[str] = "test"


class StoreWebForm(CollectionNameForm):
    """Request body for /web: the page to fetch plus an optional collection name."""

    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
136

Timothy J. Baek's avatar
Timothy J. Baek committed
137
138
@app.get("/")
async def get_status():
    """Report liveness and the current non-secret RAG settings.

    No auth dependency is declared. The values come straight from
    ``app.state``: chunking parameters, prompt template and the
    embedding/reranking model selection.
    """
    state = app.state
    return dict(
        status=True,
        chunk_size=state.CHUNK_SIZE,
        chunk_overlap=state.CHUNK_OVERLAP,
        template=state.RAG_TEMPLATE,
        embedding_engine=state.RAG_EMBEDDING_ENGINE,
        embedding_model=state.RAG_EMBEDDING_MODEL,
        reranking_model=state.RAG_RERANKING_MODEL,
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
150
151
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and OpenAI connection settings.

    Admin-only. Note that the response contains the OpenAI API key as-is.
    """
    openai_config = {
        "url": app.state.OPENAI_API_BASE_URL,
        "key": app.state.OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


Steven Kreitzer's avatar
Steven Kreitzer committed
163
164
165
166
167
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the name of the configured reranking model (admin-only)."""
    # The handler name is misspelled ("reraanking") but kept: it is the
    # module-level name other code may reference.
    reranking_model = app.state.RAG_RERANKING_MODEL
    return {"status": True, "reranking_model": reranking_model}


168
169
170
171
172
class OpenAIConfigForm(BaseModel):
    """Connection settings for the "openai" embedding engine."""

    url: str  # stored as app.state.OPENAI_API_BASE_URL
    key: str  # stored as app.state.OPENAI_API_KEY


class EmbeddingModelUpdateForm(BaseModel):
    """Request body for /embedding/update."""

    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    # "ollama"/"openai" select a remote engine; any other value (e.g. "")
    # loads embedding_model locally with sentence-transformers.
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
179
180
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model used for indexing and querying (admin-only).

    For "ollama"/"openai" no local model is kept. The OpenAI connection
    settings are updated when ``openai_config`` is supplied. Any other engine
    value loads ``form_data.embedding_model`` as a local sentence-transformers
    model.

    Returns:
        The resulting embedding configuration (same shape as GET /embedding).

    Raises:
        HTTPException(500): if the configuration cannot be applied, e.g. the
            local model fails to load. ``app.state`` is left untouched then.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        is_remote_engine = form_data.embedding_engine in ["ollama", "openai"]

        if is_remote_engine:
            # Remote engines need no local model.
            sentence_transformer_ef = None
        else:
            # Load the *requested* model, before any state is mutated, so a
            # failed load cannot leave engine/model/function out of sync.
            sentence_transformer_ef = sentence_transformers.SentenceTransformer(
                form_data.embedding_model,
                device=DEVICE_TYPE,
                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
            )

        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
        app.state.sentence_transformer_ef = sentence_transformer_ef

        if is_remote_engine and form_data.openai_config is not None:
            app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
            app.state.OPENAI_API_KEY = form_data.openai_config.key

        return {
            "status": True,
            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.OPENAI_API_BASE_URL,
                "key": app.state.OPENAI_API_KEY,
            },
        }

    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
221
222


Steven Kreitzer's avatar
Steven Kreitzer committed
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
class RerankingModelUpdateForm(BaseModel):
    """Request body for /reranking/update."""

    reranking_model: str  # passed to sentence_transformers.CrossEncoder
    

@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Replace the cross-encoder used to rerank retrieved chunks (admin-only).

    Returns:
        ``{"status": True, "reranking_model": <new model name>}``.

    Raises:
        HTTPException(500): if the model cannot be loaded. The current model
            and its recorded name are left unchanged in that case.
    """
    log.info(
        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        # Load with the same options as the startup load (including
        # trust_remote_code), and before mutating state so that a failed load
        # does not leave the model name and the loaded model out of sync.
        sentence_transformer_rf = sentence_transformers.CrossEncoder(
            form_data.reranking_model,
            device=DEVICE_TYPE,
            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
        )

        app.state.RAG_RERANKING_MODEL = form_data.reranking_model
        app.state.sentence_transformer_rf = sentence_transformer_rf

        return {
            "status": True,
            "reranking_model": app.state.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
253
254
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the document-processing settings (admin-only).

    The shape matches the response of POST /config/update.
    """
    chunk_settings = {
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
    }
    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
    }


class ChunkParamUpdateForm(BaseModel):
    """Chunking parameters handed to RecursiveCharacterTextSplitter."""

    chunk_size: int
    chunk_overlap: int


class ConfigUpdateForm(BaseModel):
    """Request body for /config/update."""

    pdf_extract_images: bool  # passed to PyPDFLoader(extract_images=...)
    chunk: ChunkParamUpdateForm


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Update PDF image extraction and chunking settings (admin-only).

    The chunk settings are validated before anything is stored. They are fed
    to ``RecursiveCharacterTextSplitter`` on every later ingestion, and the
    splitter rejects an overlap larger than the chunk size. Accepting such a
    value would therefore make every subsequent upload fail.

    Returns:
        The stored settings (same shape as GET /config).

    Raises:
        HTTPException(400): if chunk_size is not positive, chunk_overlap is
            negative, or chunk_overlap exceeds chunk_size. Nothing is changed.
    """
    chunk_size = form_data.chunk.chunk_size
    chunk_overlap = form_data.chunk.chunk_overlap
    if chunk_size <= 0 or chunk_overlap < 0 or chunk_overlap > chunk_size:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                f"Invalid chunk parameters: chunk_size={chunk_size}, "
                f"chunk_overlap={chunk_overlap} (need 0 <= overlap <= size, size > 0)"
            ),
        )

    app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
    app.state.CHUNK_SIZE = chunk_size
    app.state.CHUNK_OVERLAP = chunk_overlap

    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.CHUNK_SIZE,
            "chunk_overlap": app.state.CHUNK_OVERLAP,
        },
    }
289
290


Timothy J. Baek's avatar
Timothy J. Baek committed
291
292
293
294
295
296
297
298
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template (requires ``get_current_user``)."""
    template = app.state.RAG_TEMPLATE
    return {"status": True, "template": template}


299
300
301
302
303
304
305
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the query prompt template and default top-k (admin-only)."""
    state = app.state
    return {"status": True, "template": state.RAG_TEMPLATE, "k": state.TOP_K}
Timothy J. Baek's avatar
Timothy J. Baek committed
306
307


308
309
310
311
312
313
314
315
316
317
318
class QuerySettingsForm(BaseModel):
    """Request body for /query/settings/update; falsy fields reset to defaults."""

    k: Optional[int] = None  # falsy -> 4
    template: Optional[str] = None  # falsy -> config RAG_TEMPLATE


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Set the query prompt template and default top-k (admin-only).

    A missing/empty template restores the configured ``RAG_TEMPLATE``, and a
    missing/zero ``k`` restores the default of 4. The response echoes both
    stored values, mirroring GET /query/settings, so clients can confirm the
    effective ``k``.
    """
    app.state.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
    app.state.TOP_K = form_data.k or 4
    return {
        "status": True,
        "template": app.state.RAG_TEMPLATE,
        "k": app.state.TOP_K,
    }
320
321


322
class QueryDocForm(BaseModel):
    """Request body for /query/doc: search a single collection."""

    collection_name: str
    query: str
    k: Optional[int] = None  # falsy -> app.state.TOP_K
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
326
327


328
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query one collection via ``query_embeddings_doc``, with reranking.

    Raises:
        HTTPException(400): wraps any error raised while building the
            embedding function or running the query.
    """
    state = app.state
    try:
        embed = query_embeddings_function(
            state.RAG_EMBEDDING_ENGINE,
            state.RAG_EMBEDDING_MODEL,
            state.sentence_transformer_ef,
            state.OPENAI_API_KEY,
            state.OPENAI_API_BASE_URL,
        )
        top_k = form_data.k or state.TOP_K
        return query_embeddings_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            k=top_k,
            embeddings_function=embed,
            reranking_function=state.sentence_transformer_rf,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
355
356


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
357
358
359
class QueryCollectionsForm(BaseModel):
    """Request body for /query/collection: search several collections at once."""

    collection_names: List[str]
    query: str
    k: Optional[int] = None  # falsy -> app.state.TOP_K
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
361
362


363
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections via ``query_embeddings_collection``, with reranking.

    Raises:
        HTTPException(400): wraps any error raised while building the
            embedding function or running the query.
    """
    try:
        # Build the query-embedding callable with the imported factory. The
        # result must not be bound to the factory's own name: assigning to a
        # local with the callee's name makes the call raise UnboundLocalError.
        embeddings_function = query_embeddings_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        return query_embeddings_collection(
            collection_names=form_data.collection_names,
            query=form_data.query,
            k=form_data.k if form_data.k else app.state.TOP_K,
            embeddings_function=embeddings_function,
            reranking_function=app.state.sentence_transformer_rf,
        )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
390
391


392
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Fetch a web page and index its content into a vector collection.

    An existing collection of the same name is overwritten. When no (or an
    empty) collection name is given, one is derived from the SHA-256 of the
    URL, truncated to 63 characters.

    Raises:
        HTTPException(400): if the page cannot be loaded or yields no content.
        HTTPException(500): if the chunks could not be written to the vector DB.
    """
    # e.g. "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        stored = store_data_in_vector_db(data, collection_name, overwrite=True)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    # Write failures in store_docs_in_vector_db are reported through the
    # return value rather than raised; don't claim success in that case.
    if not stored:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "filename": form_data.url,
    }


417
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: documents as returned by a LangChain loader's ``load()``.
        collection_name: target Chroma collection name.
        overwrite: delete an existing collection of the same name first.

    Returns:
        The boolean from ``store_docs_in_vector_db``: True on success, False
        if writing to the vector DB failed.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` if splitting produced no
            chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    # Log the chunk count only; logging the chunks themselves dumps the whole
    # document text into the logs.
    log.info(f"store_data_in_vector_db: {len(docs)} chunks -> {collection_name}")

    # Return the bool itself, not a tuple: a tuple is always truthy, so the
    # callers' `if result:` checks could never detect a failed write.
    return store_docs_in_vector_db(docs, collection_name, overwrite)
432
433
434


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw string into chunks sharing ``metadata`` and store them.

    Chunking uses the current ``app.state`` chunk size/overlap. The return
    value is whatever ``store_docs_in_vector_db`` returns.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
446
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed already-split documents and add them to a new Chroma collection.

    Args:
        docs: chunk documents exposing ``page_content`` and ``metadata``.
        collection_name: collection to create and fill.
        overwrite: delete an existing collection with this name first.

    Returns:
        True on success, or when the raised exception's class is named
        ``UniqueConstraintError`` (presumably: the collection already exists).
        False on any other error. Errors are logged, never raised.
    """
    log.info(f"store_docs_in_vector_db: {len(docs)} chunks -> {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if existing.name == collection_name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = query_embeddings_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        # Newlines are flattened for embedding only; the stored documents keep
        # their original text.
        embedding_texts = [text.replace("\n", " ") for text in texts]
        if app.state.RAG_EMBEDDING_ENGINE == "":
            # The local engine's function is called once with the full list.
            embeddings = embedding_func(embedding_texts)
        else:
            # One embedding per chunk. Assumes the remote embedding function
            # takes a single string and returns one vector — confirm against
            # query_embeddings_function.
            embeddings = [embedding_func(text) for text in embedding_texts]

        for batch in create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid1()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        ):
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        if e.__class__.__name__ == "UniqueConstraintError":
            return True

        return False


495
496
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a LangChain document loader for an uploaded or scanned file.

    The decision uses the lower-cased text after the last "." in ``filename``
    (the whole name when there is no dot, e.g. "Dockerfile") and, for some
    formats, the MIME type. Checks run in a fixed priority order.

    Returns:
        ``(loader, known_type)``. ``known_type`` is False only when nothing
        matched and the plain ``TextLoader`` fallback was chosen.
    """
    ext = filename.rsplit(".", 1)[-1].lower()

    # Extensions that are read as plain text and still count as "known".
    source_code_exts = frozenset(
        {
            "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
            "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm",
            "r", "dart", "dockerfile", "env", "php", "hs", "hsc", "lua",
            "nginxconf", "conf", "m", "mm", "plsql", "perl", "rb", "rs",
            "db2", "scala", "bash", "swift", "vue", "svelte",
        }
    )
    docx_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    excel_mimes = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )

    if ext == "pdf":
        return PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES), True
    if ext == "csv":
        return CSVLoader(file_path), True
    if ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if ext in ("htm", "html"):
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    if file_content_type == docx_mime or ext in ("doc", "docx"):
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_mimes or ext in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True

    # Everything else is read as text; it is "known" only for source-code
    # extensions or a text/* MIME type.
    looks_textual = ext in source_code_exts or bool(
        file_content_type and "text/" in file_content_type
    )
    return TextLoader(file_path, autodetect_encoding=True), looks_textual


580
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under UPLOAD_DIR and index its content.

    When ``collection_name`` is omitted, it is derived from the SHA-256 of the
    file contents, truncated to 63 characters.

    Returns:
        ``status``, ``collection_name``, ``filename`` and ``known_type``.
        ``known_type`` is False when the plain-text fallback loader was used.

    Raises:
        HTTPException(400): if the file cannot be saved, loaded or split. A
            dedicated message is used when pandoc is missing.
        HTTPException(500): if the chunks could not be written to the vector DB.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        # basename() drops directory components from the client-supplied name
        # so the upload is written directly inside UPLOAD_DIR.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()
        result = store_data_in_vector_db(data, collection_name)
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    # A failed write is reported through the return value; previously this
    # fell through and answered 200 with an empty body.
    if not result:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "filename": filename,
        "known_type": known_type,
    }
635
636


637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
class TextRAGForm(BaseModel):
    """Request body for /text: raw text to index."""

    name: str  # stored in each chunk's metadata as "name"
    content: str
    collection_name: Optional[str] = None  # derived from the content hash when None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index raw text into a vector collection.

    Each chunk gets ``{"name": form_data.name, "created_by": user.id}`` as
    metadata. When no collection name is given, it is derived from the
    SHA-256 of the content, truncated to 63 characters as /web and /doc do.

    Raises:
        HTTPException(500): if the chunks could not be written to the vector DB.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        # calculate_sha256_string presumably returns a 64-char hex digest.
        # Chroma caps collection names at 63 characters, so truncate exactly
        # like the other endpoints; the untruncated name was always rejected.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}

    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=ERROR_MESSAGES.DEFAULT(),
    )


668
669
def _tags_content(tags) -> str:
    """Serialize folder tags into a document's JSON ``content`` ("{}" when there are none)."""
    if not tags:
        return "{}"
    return json.dumps({"tags": [{"name": name} for name in tags]})


def _index_scanned_file(path: Path, user) -> None:
    """Index one file found by /scan and register it as a document if it is new."""
    tags = extract_folders_after_data_docs(path)
    filename = path.name
    file_content_type = mimetypes.guess_type(path)[0]

    with open(path, "rb") as f:
        collection_name = calculate_sha256(f)[:63]

    loader, _known_type = get_loader(filename, file_content_type, str(path))
    data = loader.load()

    if not store_data_in_vector_db(data, collection_name):
        return

    sanitized_filename = sanitize_filename(filename)
    if Documents.get_doc_by_name(sanitized_filename) is None:
        Documents.insert_new_doc(
            user.id,
            DocumentForm(
                **{
                    "name": sanitized_filename,
                    "title": filename,
                    "collection_name": collection_name,
                    "filename": filename,
                    "content": _tags_content(tags),
                }
            ),
        )


@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under DOCS_DIR, recursively (admin-only).

    Each file goes into a collection named after the SHA-256 of its contents
    (truncated to 63 characters). A document record is created, tagged with
    the folders from ``extract_folders_after_data_docs``, when none exists yet
    under the sanitized filename. Errors are logged per file and do not stop
    the scan.

    Returns:
        True once every file has been visited.
    """
    # rglob("*") already walks all sub-directories.
    for path in Path(DOCS_DIR).rglob("*"):
        try:
            if path.is_file() and not path.name.startswith("."):
                _index_scanned_file(path, user)
        except Exception as e:
            # Best effort: a bad file is logged and skipped.
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
729
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the Chroma client's database (admin-only); returns None."""
    client = CHROMA_CLIENT
    client.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
732
733
734


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Empty UPLOAD_DIR and reset the vector DB (admin-only).

    Deletion and reset failures are logged, not raised; always returns True.
    """
    upload_root = f"{UPLOAD_DIR}"
    for entry in os.listdir(upload_root):
        target = os.path.join(upload_root, entry)
        try:
            # Symlinks are unlinked (never followed), real directories removed.
            is_file_or_link = os.path.isfile(target) or os.path.islink(target)
            if is_file_or_link:
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s" % (target, e))

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True