main.py 22.4 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils import embedding_functions
17
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
18

Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
21
22
23
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
24
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
25
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
26
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
28
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
29
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
31
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
)
33
34
35
36
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pydantic import BaseModel
from typing import Optional
37
import mimetypes
38
import uuid
39
40
import json

41

Timothy J. Baek's avatar
Timothy J. Baek committed
42
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
43

44
45
46
47
48
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
49

50
51
52
53
54
55
from apps.rag.utils import (
    query_doc,
    query_embeddings_doc,
    query_collection,
    query_embeddings_collection,
    get_embedding_model_path,
56
    generate_openai_embeddings,
57
)
Timothy J. Baek's avatar
Timothy J. Baek committed
58

59
60
61
62
63
64
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
65
from utils.utils import get_current_user, get_admin_user
66
from config import (
67
    SRC_LOG_LEVELS,
68
69
    UPLOAD_DIR,
    DOCS_DIR,
70
    RAG_EMBEDDING_ENGINE,
71
    RAG_EMBEDDING_MODEL,
72
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
73
    DEVICE_TYPE,
74
75
76
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
77
    RAG_TEMPLATE,
78
)
79

80
81
from constants import ERROR_MESSAGES

82
83
84
# Module logger; verbosity comes from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Runtime-tunable settings are kept on app.state so the endpoints in this
# module can read and update them while the process is running.
app.state.TOP_K = 4  # default number of results returned per query
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
app.state.RAG_OPENAI_API_KEY = ""

app.state.PDF_EXTRACT_IMAGES = False

# The sentence-transformer embedding function is built at import time from the
# configured model, whichever embedding engine is selected.
app.state.sentence_transformer_ef = (
    embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=get_embedding_model_path(
            app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
        ),
        device=DEVICE_TYPE,
    )
)

# NOTE(review): every origin is allowed together with credentials; confirm
# this is intended for the deployment this sub-application runs in.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
124
class CollectionNameForm(BaseModel):
    # Request body carrying the name of a target vector-store collection.
    # Defaults to "test" when the field is omitted.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
128
129
130
class StoreWebForm(CollectionNameForm):
    # Request body for POST /web: the page URL to load, plus the inherited
    # collection_name field.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
131

Timothy J. Baek's avatar
Timothy J. Baek committed
132
133
@app.get("/")
async def get_status():
    """Return the current chunking, template and embedding settings."""
    state = app.state
    return {
        "status": True,
        "chunk_size": state.CHUNK_SIZE,
        "chunk_overlap": state.CHUNK_OVERLAP,
        "template": state.RAG_TEMPLATE,
        "embedding_engine": state.RAG_EMBEDDING_ENGINE,
        "embedding_model": state.RAG_EMBEDDING_MODEL,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
144
145
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the active embedding engine/model and the stored OpenAI settings.

    Guarded by the ``get_admin_user`` dependency. The response includes the
    stored OpenAI API key verbatim.
    """
    state = app.state
    openai_config = {
        "url": state.RAG_OPENAI_API_BASE_URL,
        "key": state.RAG_OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": state.RAG_EMBEDDING_ENGINE,
        "embedding_model": state.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


157
158
159
160
161
class OpenAIConfigForm(BaseModel):
    # Connection settings stored as RAG_OPENAI_API_BASE_URL / RAG_OPENAI_API_KEY.
    url: str
    key: str


162
class EmbeddingModelUpdateForm(BaseModel):
    # Request body for POST /embedding/update.
    # openai_config is optional; the handler only applies it for the
    # "ollama" / "openai" engines.
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
168
169
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine and model used for indexing and querying.

    For the "ollama" and "openai" engines no local model is loaded and the
    optional ``openai_config`` (URL and key) is stored. For any other engine
    value a SentenceTransformer embedding function is built for the requested
    model.

    The new embedding function is prepared *before* anything is written to
    ``app.state``, so a failed model load leaves the previous engine, model and
    embedding function in place rather than a half-updated configuration.

    Raises:
        HTTPException: 500 if preparing or applying the configuration fails.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        uses_remote_engine = form_data.embedding_engine in ["ollama", "openai"]

        if uses_remote_engine:
            sentence_transformer_ef = None
        else:
            # This may raise; nothing has been committed to app.state yet.
            sentence_transformer_ef = (
                embedding_functions.SentenceTransformerEmbeddingFunction(
                    model_name=get_embedding_model_path(
                        form_data.embedding_model, True
                    ),
                    device=DEVICE_TYPE,
                )
            )

        # Commit the new configuration in one go.
        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
        app.state.sentence_transformer_ef = sentence_transformer_ef

        if uses_remote_engine and form_data.openai_config is not None:
            app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
            app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key

        return {
            "status": True,
            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.RAG_OPENAI_API_BASE_URL,
                "key": app.state.RAG_OPENAI_API_KEY,
            },
        }

    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
213
214


Timothy J. Baek's avatar
Timothy J. Baek committed
215
216
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF image-extraction flag and the chunking parameters."""
    chunk = {
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
    }
    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": chunk,
    }


class ChunkParamUpdateForm(BaseModel):
    # Values stored as app.state.CHUNK_SIZE / CHUNK_OVERLAP and later passed
    # to the text splitter.
    chunk_size: int
    chunk_overlap: int


Timothy J. Baek's avatar
Timothy J. Baek committed
232
233
234
235
236
237
238
239
240
241
class ConfigUpdateForm(BaseModel):
    # Request body for POST /config/update.
    pdf_extract_images: bool
    chunk: ChunkParamUpdateForm


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Store new PDF-extraction and chunking settings and echo them back."""
    state = app.state
    chunk_params = form_data.chunk

    state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
    state.CHUNK_SIZE = chunk_params.chunk_size
    state.CHUNK_OVERLAP = chunk_params.chunk_overlap

    return {
        "status": True,
        "pdf_extract_images": state.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": state.CHUNK_SIZE,
            "chunk_overlap": state.CHUNK_OVERLAP,
        },
    }
251
252


Timothy J. Baek's avatar
Timothy J. Baek committed
253
254
255
256
257
258
259
260
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template currently held in app.state."""
    template = app.state.RAG_TEMPLATE
    return {"status": True, "template": template}


261
262
263
264
265
266
267
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the prompt template and the default result count ``k``."""
    state = app.state
    return {"status": True, "template": state.RAG_TEMPLATE, "k": state.TOP_K}
Timothy J. Baek's avatar
Timothy J. Baek committed
268
269


270
271
272
273
274
275
276
277
278
279
280
class QuerySettingsForm(BaseModel):
    # Request body for POST /query/settings/update. A missing or falsy field
    # makes the handler fall back to its default for that setting.
    k: Optional[int] = None
    template: Optional[str] = None


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Update the prompt template and the default result count.

    A missing or empty ``template`` resets the template to the configured
    ``RAG_TEMPLATE``; a missing or zero ``k`` resets the result count to 4.

    Returns:
        The effective template and ``k`` after the update. ``k`` is included
        so the response mirrors ``GET /query/settings`` and the caller can see
        which value was actually applied.
    """
    app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
    app.state.TOP_K = form_data.k if form_data.k else 4
    return {
        "status": True,
        "template": app.state.RAG_TEMPLATE,
        "k": app.state.TOP_K,
    }
282
283


284
class QueryDocForm(BaseModel):
    # Request body for POST /query/doc. When k is omitted (or falsy) the
    # handler uses app.state.TOP_K.
    collection_name: str
    query: str
    k: Optional[int] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
288
289


290
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single collection with the configured embedding engine.

    With the empty-string engine the query text is handed to ``query_doc``
    together with the local sentence-transformer embedding function. With
    "ollama" or "openai" the query is embedded first and the resulting vector
    is passed to ``query_embeddings_doc``.

    Raises:
        HTTPException: 400 on any failure, including an embedding engine this
            handler does not know how to use.
    """
    try:
        top_k = form_data.k if form_data.k else app.state.TOP_K
        engine = app.state.RAG_EMBEDDING_ENGINE

        if engine == "":
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                k=top_k,
                embedding_function=app.state.sentence_transformer_ef,
            )

        if engine == "ollama":
            query_embeddings = generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": app.state.RAG_EMBEDDING_MODEL,
                        "prompt": form_data.query,
                    }
                )
            )
        elif engine == "openai":
            query_embeddings = generate_openai_embeddings(
                model=app.state.RAG_EMBEDDING_MODEL,
                text=form_data.query,
                key=app.state.RAG_OPENAI_API_KEY,
                url=app.state.RAG_OPENAI_API_BASE_URL,
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # query_embeddings further down.
            raise ValueError(f"Unsupported embedding engine: {engine}")

        return query_embeddings_doc(
            collection_name=form_data.collection_name,
            query_embeddings=query_embeddings,
            k=top_k,
        )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
334
335


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
336
337
338
class QueryCollectionsForm(BaseModel):
    # Request body for POST /query/collection. When k is omitted (or falsy)
    # the handler uses app.state.TOP_K.
    collection_names: List[str]
    query: str
    k: Optional[int] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
340
341


342
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections with the configured embedding engine.

    With the empty-string engine the query text is handed to
    ``query_collection`` together with the local sentence-transformer
    embedding function. With "ollama" or "openai" the query is embedded first
    and the vector is passed to ``query_embeddings_collection``.

    Raises:
        HTTPException: 400 on any failure, including an embedding engine this
            handler does not know how to use.
    """
    try:
        top_k = form_data.k if form_data.k else app.state.TOP_K
        engine = app.state.RAG_EMBEDDING_ENGINE

        if engine == "":
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                k=top_k,
                embedding_function=app.state.sentence_transformer_ef,
            )

        if engine == "ollama":
            query_embeddings = generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": app.state.RAG_EMBEDDING_MODEL,
                        "prompt": form_data.query,
                    }
                )
            )
        elif engine == "openai":
            query_embeddings = generate_openai_embeddings(
                model=app.state.RAG_EMBEDDING_MODEL,
                text=form_data.query,
                key=app.state.RAG_OPENAI_API_KEY,
                url=app.state.RAG_OPENAI_API_BASE_URL,
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # query_embeddings further down.
            raise ValueError(f"Unsupported embedding engine: {engine}")

        return query_embeddings_collection(
            collection_names=form_data.collection_names,
            query_embeddings=query_embeddings,
            k=top_k,
        )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
386
387


388
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Load a web page and index its content, replacing any existing collection.

    When no usable collection name is supplied (empty string or ``null``), the
    name is derived from the SHA-256 of the URL, truncated to 63 characters.

    Raises:
        HTTPException: 400 if loading or indexing the page fails.
    """
    # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    # NOTE(review): the URL is fetched server-side exactly as supplied by the
    # caller; confirm whether internal/private addresses should be rejected.
    try:
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        # The field is Optional, so treat None the same as "" rather than
        # passing None through as a collection name.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


413
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in a collection.

    Args:
        data: Documents as returned by a loader's ``load()``.
        collection_name: Name of the target collection.
        overwrite: Forwarded to ``store_docs_in_vector_db``.

    Returns:
        The boolean result of ``store_docs_in_vector_db`` (not a tuple, so
        callers' ``if result:`` checks reflect the real outcome).

    Raises:
        ValueError: If splitting produces no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    # Single-message log call: extra positional args without a matching
    # placeholder make the logging module report a formatting error.
    log.info("store_data_in_vector_db -> store_docs_in_vector_db")
    return store_docs_in_vector_db(docs, collection_name, overwrite)
428
429
430


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk one raw text (with its metadata) and store the chunks.

    Returns whatever ``store_docs_in_vector_db`` returns for the chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
442
443
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed document chunks and add them to a new Chroma collection.

    Args:
        docs: Chunks exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, an existing collection with the same name is
            deleted first.

    Returns:
        True on success, or when the failure is a ``UniqueConstraintError``
        (treated as "already stored"); False on any other failure.
    """
    log.info(
        "store_docs_in_vector_db: %d docs -> collection %s", len(docs), collection_name
    )

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if collection_name == existing.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        engine = app.state.RAG_EMBEDDING_ENGINE
        ids = [str(uuid.uuid1()) for _ in texts]

        if engine == "":
            # Local engine: Chroma embeds the documents itself through the
            # sentence-transformer embedding function.
            collection = CHROMA_CLIENT.create_collection(
                name=collection_name,
                embedding_function=app.state.sentence_transformer_ef,
            )
            batches = create_batches(
                api=CHROMA_CLIENT,
                ids=ids,
                metadatas=metadatas,
                documents=texts,
            )
        else:
            if engine not in ("ollama", "openai"):
                raise ValueError(f"Unsupported embedding engine: {engine}")

            # Create the target collection explicitly; it is created before
            # embedding so an already-existing name fails fast.
            collection = CHROMA_CLIENT.create_collection(name=collection_name)

            if engine == "ollama":
                embeddings = [
                    generate_ollama_embeddings(
                        GenerateEmbeddingsForm(
                            **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
                        )
                    )
                    for text in texts
                ]
            else:
                embeddings = [
                    generate_openai_embeddings(
                        model=app.state.RAG_EMBEDDING_MODEL,
                        text=text,
                        key=app.state.RAG_OPENAI_API_KEY,
                        url=app.state.RAG_OPENAI_API_BASE_URL,
                    )
                    for text in texts
                ]

            # Store the chunk text alongside the precomputed vectors so that
            # queries can return the document content.
            batches = create_batches(
                api=CHROMA_CLIENT,
                ids=ids,
                metadatas=metadatas,
                embeddings=embeddings,
                documents=texts,
            )

        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # A name clash means the collection already exists; callers treat
        # that as success.
        if e.__class__.__name__ == "UniqueConstraintError":
            return True

        return False


508
509
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a document loader from the file extension and content type.

    Returns:
        A ``(loader, known_type)`` pair. ``known_type`` is False only when no
        rule matched and the plain-text loader is used as a last resort.
    """
    file_ext = filename.split(".")[-1].lower()

    # Extensions that are always read as plain text.
    known_source_ext = [
        "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
        "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm",
        "r", "dart", "dockerfile", "env", "php", "hs", "hsc", "lua",
        "nginxconf", "conf", "m", "mm", "plsql", "perl", "rb", "rs",
        "db2", "scala", "bash", "swift", "vue", "svelte",
    ]
    docx_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    excel_mimes = [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ]

    # Rules are checked in order; the first match wins.
    if file_ext == "pdf":
        return (
            PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES),
            True,
        )
    if file_ext == "csv":
        return CSVLoader(file_path), True
    if file_ext == "rst":
        return UnstructuredRSTLoader(file_path, mode="elements"), True
    if file_ext == "xml":
        return UnstructuredXMLLoader(file_path), True
    if file_ext in ["htm", "html"]:
        return BSHTMLLoader(file_path, open_encoding="unicode_escape"), True
    if file_ext == "md":
        return UnstructuredMarkdownLoader(file_path), True
    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True
    if file_content_type == docx_mime or file_ext in ["doc", "docx"]:
        return Docx2txtLoader(file_path), True
    if file_content_type in excel_mimes or file_ext in ["xls", "xlsx"]:
        return UnstructuredExcelLoader(file_path), True
    if file_ext in known_source_ext or (
        file_content_type and "text/" in file_content_type
    ):
        return TextLoader(file_path, autodetect_encoding=True), True

    # Nothing matched: still try to read the file as text, but flag it.
    return TextLoader(file_path, autodetect_encoding=True), False


593
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under UPLOAD_DIR and index its content.

    When ``collection_name`` is not supplied it is derived from the SHA-256 of
    the saved file, truncated to 63 characters.

    Raises:
        HTTPException: 500 if the vector store reports that nothing was
            stored; 400 for any other failure (with a dedicated message when
            the error text mentions that pandoc is missing).
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # Keep only the final path component of the client-supplied name.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        result = store_data_in_vector_db(data, collection_name)
        if not result:
            # Report the failure instead of returning an empty 200 response.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as a 400.
        raise
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
648
649


650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
class TextRAGForm(BaseModel):
    # Request body for POST /text. name is stored in the chunk metadata;
    # collection_name is optional and derived from the content when omitted.
    name: str
    content: str
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index a raw text snippet, tagging chunks with its name and the user id.

    When no collection name is supplied it is derived from the SHA-256 of the
    content, truncated to 63 characters like the other endpoints in this
    module (a full hex digest is 64 characters long).

    Raises:
        HTTPException: 500 if the text could not be stored.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


681
682
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under DOCS_DIR and register new documents.

    Each file is stored in a collection named after the first 63 characters of
    its SHA-256. When storing succeeds and no document with the sanitized file
    name exists yet, a document record is inserted whose content holds the
    tags returned by ``extract_folders_after_data_docs``.

    The scan is best-effort: a failure on one file is logged and the scan
    moves on to the next. Always returns True.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # Context manager so the handle is closed even if hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            try:
                result = store_data_in_vector_db(data, collection_name)

                if result:
                    sanitized_filename = sanitize_filename(filename)
                    doc = Documents.get_doc_by_name(sanitized_filename)

                    if doc is None:
                        content = (
                            json.dumps({"tags": [{"name": name} for name in tags]})
                            if len(tags)
                            else "{}"
                        )
                        Documents.insert_new_doc(
                            user.id,
                            DocumentForm(
                                **{
                                    "name": sanitized_filename,
                                    "title": filename,
                                    "collection_name": collection_name,
                                    "filename": filename,
                                    "content": content,
                                }
                            ),
                        )
            except Exception as e:
                # Storing or registering one file must not abort the scan.
                log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
742
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Call ``reset()`` on the Chroma client; errors propagate to the caller."""
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
745
746
747


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Empty UPLOAD_DIR and reset the Chroma client.

    Both steps are best-effort: failures are logged and the endpoint still
    returns True.
    """
    folder = f"{UPLOAD_DIR}"
    for entry_name in os.listdir(folder):
        file_path = os.path.join(folder, entry_name)
        try:
            is_file_or_link = os.path.isfile(file_path) or os.path.islink(file_path)
            if is_file_or_link:
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s" % (file_path, e))

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True