main.py 22.4 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils import embedding_functions
17
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
18

Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
21
22
23
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
24
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
25
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
26
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
28
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
29
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
31
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
)
33
34
35
36
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pydantic import BaseModel
from typing import Optional
37
import mimetypes
38
import uuid
39
40
import json

41

Timothy J. Baek's avatar
Timothy J. Baek committed
42
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
43

44
45
46
47
48
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
49

50
51
52
53
54
55
from apps.rag.utils import (
    query_doc,
    query_embeddings_doc,
    query_collection,
    query_embeddings_collection,
    get_embedding_model_path,
56
    generate_openai_embeddings,
57
)
Timothy J. Baek's avatar
Timothy J. Baek committed
58

59
60
61
62
63
64
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
65
from utils.utils import get_current_user, get_admin_user
66
from config import (
67
    SRC_LOG_LEVELS,
68
69
    UPLOAD_DIR,
    DOCS_DIR,
70
    RAG_EMBEDDING_ENGINE,
71
    RAG_EMBEDDING_MODEL,
72
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
73
    DEVICE_TYPE,
74
75
76
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
77
    RAG_TEMPLATE,
78
)
79

80
81
from constants import ERROR_MESSAGES

82
83
84
# Module logger; the RAG sub-app's verbosity comes from SRC_LOG_LEVELS["RAG"].
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# --- Retrieval defaults (mutable at runtime through the admin endpoints) ---
app.state.TOP_K = 4
app.state.RAG_TEMPLATE = RAG_TEMPLATE

# --- Chunking defaults used when splitting documents before indexing ---
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

# --- Embedding backend selection ---
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL

# Connection settings for the "openai" embedding engine (replaced via /embedding/update).
app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
app.state.RAG_OPENAI_API_KEY = ""

# Passed to PyPDFLoader(extract_images=...) when loading PDFs.
app.state.PDF_EXTRACT_IMAGES = False

# Local SentenceTransformer embedding function, constructed eagerly at import time.
app.state.sentence_transformer_ef = (
    embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=get_embedding_model_path(
            app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
        ),
        device=DEVICE_TYPE,
    )
)

# Fully permissive CORS: any origin, method and header, with credentials allowed.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
124
class CollectionNameForm(BaseModel):
    # Target vector-store collection; defaults to "test" when the client omits it
    # (an explicit null is still accepted because the type is Optional).
    collection_name: Optional[str] = "test"


class StoreWebForm(CollectionNameForm):
    # URL of the page to fetch and index; `collection_name` is inherited.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
131

Timothy J. Baek's avatar
Timothy J. Baek committed
132
133
@app.get("/")
async def get_status():
    """Health check that also echoes the active chunking, template and embedding settings."""
    state = app.state
    return dict(
        status=True,
        chunk_size=state.CHUNK_SIZE,
        chunk_overlap=state.CHUNK_OVERLAP,
        template=state.RAG_TEMPLATE,
        embedding_engine=state.RAG_EMBEDDING_ENGINE,
        embedding_model=state.RAG_EMBEDDING_MODEL,
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
144
145
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the embedding engine/model and OpenAI connection settings (admin only).

    NOTE(review): the OpenAI API key is returned in clear text; access is limited
    to admins by the ``get_admin_user`` dependency.
    """
    state = app.state
    openai_config = {
        "url": state.RAG_OPENAI_API_BASE_URL,
        "key": state.RAG_OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": state.RAG_EMBEDDING_ENGINE,
        "embedding_model": state.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


157
158
159
160
161
class OpenAIConfigForm(BaseModel):
    # Base URL and API key for an OpenAI-compatible embeddings endpoint.
    url: str
    key: str


class EmbeddingModelUpdateForm(BaseModel):
    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    # "" selects the local SentenceTransformer; "ollama"/"openai" select remote engines.
    embedding_engine: str
    embedding_model: str


@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model used for indexing and querying (admin only).

    For "ollama"/"openai" the local SentenceTransformer is dropped and, if given,
    the OpenAI connection settings are updated. For any other engine value the
    SentenceTransformer model is (down)loaded first; only if that succeeds are the
    engine, model and embedding function committed to ``app.state`` together, so a
    failed load leaves the previous configuration fully intact.

    Returns:
        The resulting engine, model and OpenAI connection settings.

    Raises:
        HTTPException(500): if loading the model (or anything else) fails.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        uses_remote_engine = form_data.embedding_engine in ["ollama", "openai"]

        if uses_remote_engine:
            sentence_transformer_ef = None
        else:
            # Build before mutating state: this may download a model and can fail.
            sentence_transformer_ef = (
                embedding_functions.SentenceTransformerEmbeddingFunction(
                    model_name=get_embedding_model_path(
                        form_data.embedding_model, True
                    ),
                    device=DEVICE_TYPE,
                )
            )

        # Commit engine, model and embedding function as one consistent unit.
        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
        app.state.sentence_transformer_ef = sentence_transformer_ef

        if uses_remote_engine and form_data.openai_config is not None:
            app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
            app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key

        return {
            "status": True,
            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.RAG_OPENAI_API_BASE_URL,
                "key": app.state.RAG_OPENAI_API_KEY,
            },
        }

    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
213
214


Timothy J. Baek's avatar
Timothy J. Baek committed
215
216
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the document-ingestion settings (PDF image extraction, chunking); admin only."""
    chunk_settings = {
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
    }
    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
    }


class ChunkParamUpdateForm(BaseModel):
    # Characters per chunk and characters shared between consecutive chunks.
    chunk_size: int
    chunk_overlap: int


class ConfigUpdateForm(BaseModel):
    pdf_extract_images: bool
    chunk: ChunkParamUpdateForm


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Update PDF image extraction and chunking parameters (admin only).

    The chunk parameters are validated before anything is changed.
    RecursiveCharacterTextSplitter raises when the overlap exceeds the chunk size.
    Without this check, an invalid pair would be accepted here and then make every
    later document upload fail.

    Returns:
        The settings now in effect.

    Raises:
        HTTPException(400): chunk_size <= 0, chunk_overlap < 0, or
            chunk_overlap > chunk_size. In that case state is left unchanged.
    """
    chunk_size = form_data.chunk.chunk_size
    chunk_overlap = form_data.chunk.chunk_overlap

    if chunk_size <= 0 or chunk_overlap < 0 or chunk_overlap > chunk_size:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(
                "chunk_size must be positive and chunk_overlap must be between 0 "
                f"and chunk_size (got chunk_size={chunk_size}, "
                f"chunk_overlap={chunk_overlap})"
            ),
        )

    app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
    app.state.CHUNK_SIZE = chunk_size
    app.state.CHUNK_OVERLAP = chunk_overlap

    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.CHUNK_SIZE,
            "chunk_overlap": app.state.CHUNK_OVERLAP,
        },
    }
251
252


Timothy J. Baek's avatar
Timothy J. Baek committed
253
254
255
256
257
258
259
260
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the prompt template that wraps retrieved context (any signed-in user)."""
    template = app.state.RAG_TEMPLATE
    return {"status": True, "template": template}


261
262
263
264
265
266
267
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the retrieval prompt template and default top-k (admin only)."""
    state = app.state
    return dict(status=True, template=state.RAG_TEMPLATE, k=state.TOP_K)
Timothy J. Baek's avatar
Timothy J. Baek committed
268
269


270
271
272
273
274
275
276
277
278
279
280
class QuerySettingsForm(BaseModel):
    # Both optional: a missing/falsy value resets that setting to its default.
    k: Optional[int] = None
    template: Optional[str] = None


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Update the retrieval template and default top-k (admin only).

    A falsy ``template`` (None or "") restores the configured RAG_TEMPLATE. A falsy
    ``k`` (None or 0) restores 4.

    Returns:
        The applied ``template`` and ``k``, matching the shape of
        ``GET /query/settings``, so callers can see when a fallback was used.
    """
    app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
    app.state.TOP_K = form_data.k if form_data.k else 4
    return {
        "status": True,
        "template": app.state.RAG_TEMPLATE,
        "k": app.state.TOP_K,
    }
282
283


284
class QueryDocForm(BaseModel):
    collection_name: str
    query: str
    # Overrides app.state.TOP_K when truthy.
    k: Optional[int] = None


@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Return the top-k chunks of one collection that best match ``form_data.query``.

    With the local engine ("") the query is embedded by the SentenceTransformer
    function. With "ollama"/"openai" the query embedding is generated first and
    passed to ``query_embeddings_doc``.

    Raises:
        HTTPException(400): on any retrieval or embedding error, including an
            unsupported ``RAG_EMBEDDING_ENGINE`` value.
    """
    try:
        k = form_data.k if form_data.k else app.state.TOP_K
        engine = app.state.RAG_EMBEDDING_ENGINE

        if engine == "":
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                k=k,
                embedding_function=app.state.sentence_transformer_ef,
            )

        if engine == "ollama":
            query_embeddings = generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": app.state.RAG_EMBEDDING_MODEL,
                        "prompt": form_data.query,
                    }
                )
            )
        elif engine == "openai":
            query_embeddings = generate_openai_embeddings(
                model=app.state.RAG_EMBEDDING_MODEL,
                text=form_data.query,
                key=app.state.RAG_OPENAI_API_KEY,
                url=app.state.RAG_OPENAI_API_BASE_URL,
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(f"Unsupported embedding engine: {engine}")

        return query_embeddings_doc(
            collection_name=form_data.collection_name,
            query_embeddings=query_embeddings,
            k=k,
        )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
334
335


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
336
337
338
class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str
    # Overrides app.state.TOP_K when truthy.
    k: Optional[int] = None


@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Return the top-k chunks across several collections for ``form_data.query``.

    With the local engine ("") the query is embedded by the SentenceTransformer
    function. With "ollama"/"openai" the query embedding is generated first and
    passed to ``query_embeddings_collection``.

    Raises:
        HTTPException(400): on any retrieval or embedding error, including an
            unsupported ``RAG_EMBEDDING_ENGINE`` value.
    """
    try:
        k = form_data.k if form_data.k else app.state.TOP_K
        engine = app.state.RAG_EMBEDDING_ENGINE

        if engine == "":
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                k=k,
                embedding_function=app.state.sentence_transformer_ef,
            )

        if engine == "ollama":
            query_embeddings = generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": app.state.RAG_EMBEDDING_MODEL,
                        "prompt": form_data.query,
                    }
                )
            )
        elif engine == "openai":
            query_embeddings = generate_openai_embeddings(
                model=app.state.RAG_EMBEDDING_MODEL,
                text=form_data.query,
                key=app.state.RAG_OPENAI_API_KEY,
                url=app.state.RAG_OPENAI_API_BASE_URL,
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(f"Unsupported embedding engine: {engine}")

        return query_embeddings_collection(
            collection_names=form_data.collection_names,
            query_embeddings=query_embeddings,
            k=k,
        )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
386
387


388
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Fetch a web page and index its content into a vector collection.

    The page is loaded with ``WebBaseLoader`` and stored with ``overwrite=True``,
    which replaces any existing collection of the same name. If ``collection_name``
    is empty or null, a name is derived from the SHA-256 of the URL, truncated to
    63 characters (presumably the vector store's collection-name limit).

    Returns:
        ``{"status": True, "collection_name": ..., "filename": <url>}``.

    Raises:
        HTTPException(400): fetching, splitting or storing failed.
    """
    # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        # `not` covers both "" and an explicit null (the field is Optional[str]).
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        # The store helpers report failure by returning a falsy value, not by raising.
        if not store_data_in_vector_db(data, collection_name, overwrite=True):
            raise ValueError(
                f"Failed to store content from {form_data.url} in the vector database"
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


413
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded LangChain documents into chunks and store them in ``collection_name``.

    Chunking uses the runtime-configurable ``app.state.CHUNK_SIZE`` and
    ``app.state.CHUNK_OVERLAP``. ``add_start_index=True`` makes LangChain record
    each chunk's offset in its metadata.

    Returns:
        The boolean result of ``store_docs_in_vector_db``. The former
        ``(result, None)`` tuple was always truthy, which hid storage failures from
        callers that test ``if result:``.

    Raises:
        ValueError: if splitting yields no chunks (empty content).
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    # Log only the count: dumping every chunk's text at INFO floods the logs.
    log.info(
        f"store_data_in_vector_db: {len(docs)} chunks for collection {collection_name}"
    )
    return store_docs_in_vector_db(docs, collection_name, overwrite)
428
429
430


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk a raw string and persist the chunks into ``collection_name``.

    Every chunk is created with ``metadata``. Chunk size and overlap are read from
    ``app.state``. Storage, and its True/False result, is delegated to
    ``store_docs_in_vector_db``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(
        chunks, collection_name=collection_name, overwrite=overwrite
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
442
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed ``docs`` and add them to a newly created Chroma collection.

    The embedding backend is chosen by ``app.state.RAG_EMBEDDING_ENGINE``:
      * ``""``: the local SentenceTransformer function is attached to the
        collection, so Chroma embeds the documents itself.
      * ``"ollama"``: one ``generate_ollama_embeddings`` call per chunk.
      * ``"openai"``: one ``generate_openai_embeddings`` call per chunk.
    Any other value is rejected before a collection is created.

    Args:
        docs: chunked LangChain documents (``page_content`` + ``metadata``).
        collection_name: target collection. It must not already exist unless
            ``overwrite`` is set.
        overwrite: delete an existing collection with this name first.

    Returns:
        True on success, or when the collection already exists
        (``UniqueConstraintError``). Callers presumably derive names from content
        hashes, so this means the content was indexed before.
        False on any other failure. The error is logged, and any collection created
        by this call is deleted. Otherwise a retry would hit UniqueConstraintError
        and report success over an empty or partial collection.
    """
    log.info(f"store_docs_in_vector_db: {len(docs)} chunks into {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]
    engine = app.state.RAG_EMBEDDING_ENGINE
    created = False

    try:
        if overwrite:
            for collection in CHROMA_CLIENT.list_collections():
                if collection_name == collection.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        if engine == "":
            collection = CHROMA_CLIENT.create_collection(
                name=collection_name,
                embedding_function=app.state.sentence_transformer_ef,
            )
            created = True

            for batch in create_batches(
                api=CHROMA_CLIENT,
                ids=[str(uuid.uuid1()) for _ in texts],
                metadatas=metadatas,
                documents=texts,
            ):
                collection.add(*batch)
        else:
            # Validate before creating anything, so a bad engine leaves no empty collection.
            if engine not in ("ollama", "openai"):
                raise ValueError(f"Unsupported embedding engine: {engine}")

            collection = CHROMA_CLIENT.create_collection(name=collection_name)
            created = True

            if engine == "ollama":
                embeddings = [
                    generate_ollama_embeddings(
                        GenerateEmbeddingsForm(
                            **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
                        )
                    )
                    for text in texts
                ]
            else:
                embeddings = [
                    generate_openai_embeddings(
                        model=app.state.RAG_EMBEDDING_MODEL,
                        text=text,
                        key=app.state.RAG_OPENAI_API_KEY,
                        url=app.state.RAG_OPENAI_API_BASE_URL,
                    )
                    for text in texts
                ]

            for batch in create_batches(
                api=CHROMA_CLIENT,
                ids=[str(uuid.uuid1()) for _ in texts],
                metadatas=metadatas,
                embeddings=embeddings,
                documents=texts,
            ):
                collection.add(*batch)

        return True
    except Exception as e:
        if e.__class__.__name__ == "UniqueConstraintError":
            # Collection already exists: treated as "already indexed", not an error.
            log.info(f"collection {collection_name} already exists; skipping")
            return True

        log.exception(e)

        if created:
            # Drop the partially populated collection created by this call.
            try:
                CHROMA_CLIENT.delete_collection(name=collection_name)
            except Exception as cleanup_error:
                log.exception(cleanup_error)

        return False


511
512
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a LangChain document loader for a file.

    Selection order (first match wins):
      1. extension: pdf, csv, rst, xml, htm/html, md;
      2. content type ``application/epub+zip``;
      3. Word: the DOCX content type, or a ``.doc``/``.docx`` extension;
      4. Excel: the XLS/XLSX content types, or a ``.xls``/``.xlsx`` extension;
      5. a known source/text extension, or any content type containing ``text/``;
      6. anything else falls back to a plain TextLoader flagged as unknown.

    A filename without a dot uses the whole lower-cased name as its "extension".
    That is how e.g. ``Dockerfile`` matches ``dockerfile`` below.

    Returns:
        ``(loader, known_type)``. ``known_type`` is False only for the fallback.
    """
    ext = filename.split(".")[-1].lower()

    source_code_exts = frozenset(
        {
            "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
            "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm",
            "r", "dart", "dockerfile", "env", "php", "hs", "hsc", "lua",
            "nginxconf", "conf", "m", "mm", "plsql", "perl", "rb", "rs",
            "db2", "scala", "bash", "swift", "vue", "svelte",
        }
    )

    # Loaders chosen purely by extension; these take precedence over content-type rules.
    # Factories are lazy so PDF_EXTRACT_IMAGES is only read for PDFs.
    by_extension = {
        "pdf": lambda: PyPDFLoader(
            file_path, extract_images=app.state.PDF_EXTRACT_IMAGES
        ),
        "csv": lambda: CSVLoader(file_path),
        "rst": lambda: UnstructuredRSTLoader(file_path, mode="elements"),
        "xml": lambda: UnstructuredXMLLoader(file_path),
        "htm": lambda: BSHTMLLoader(file_path, open_encoding="unicode_escape"),
        "html": lambda: BSHTMLLoader(file_path, open_encoding="unicode_escape"),
        "md": lambda: UnstructuredMarkdownLoader(file_path),
    }

    factory = by_extension.get(ext)
    if factory is not None:
        return factory(), True

    if file_content_type == "application/epub+zip":
        return UnstructuredEPubLoader(file_path), True

    is_word = (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or ext in ("doc", "docx")
    )
    if is_word:
        # NOTE(review): legacy binary .doc files are not zip-based DOCX, so
        # Docx2txtLoader likely cannot read them. Verify the intended handling.
        return Docx2txtLoader(file_path), True

    excel_types = (
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
    if file_content_type in excel_types or ext in ("xls", "xlsx"):
        return UnstructuredExcelLoader(file_path), True

    # Everything left is read as text; only the "known" flag differs.
    looks_textual = ext in source_code_exts or bool(
        file_content_type and "text/" in file_content_type
    )
    return TextLoader(file_path, autodetect_encoding=True), looks_textual


596
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file to UPLOAD_DIR, load it with a matching loader and index it.

    If ``collection_name`` is omitted, it is the SHA-256 of the saved file,
    truncated to 63 characters.

    Returns:
        ``{"status": True, "collection_name", "filename", "known_type"}``.

    Raises:
        HTTPException(400): loading or indexing raised. A dedicated message is used
            when pandoc is missing.
        HTTPException(500): the vector store reported failure without raising.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        # basename() strips client-supplied directory components from the name.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        result = store_data_in_vector_db(data, collection_name)
        if not result:
            # Previously this fell through and returned null with HTTP 200.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as a 400.
        raise
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
651
652


653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
class TextRAGForm(BaseModel):
    # Display name stored in each chunk's metadata.
    name: str
    # Raw text to chunk and index.
    content: str
    # Derived from the content hash when omitted.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Index raw text into a vector collection.

    Without an explicit ``collection_name``, the name is the SHA-256 of the content
    truncated to 63 characters. This matches the other content-addressed names in
    this module. The full hash string is presumably 64 hex characters, which would
    exceed the vector store's collection-name limit.

    Returns:
        ``{"status": True, "collection_name": ...}``.

    Raises:
        HTTPException(500): if storing the chunks fails.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}

    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=ERROR_MESSAGES.DEFAULT(),
    )


684
685
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under DOCS_DIR and register it as a document (admin only).

    For each file:
      * the collection name is the file's SHA-256, truncated to 63 characters;
      * the loader comes from ``get_loader``;
      * folder names from ``extract_folders_after_data_docs`` become tags.
    A Documents row is inserted only if the store succeeded and no document with the
    sanitized filename exists yet. Per-file failures are logged and skipped, so one
    bad file does not abort the scan.

    Returns:
        True once the walk completes.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # `with` guarantees the handle is closed even if hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            if not store_data_in_vector_db(data, collection_name):
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is not None:
                continue

            content = (
                json.dumps({"tags": [{"name": name} for name in tags]})
                if tags
                else "{}"
            )
            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    **{
                        "name": sanitized_filename,
                        "title": filename,
                        "collection_name": collection_name,
                        "filename": filename,
                        "content": content,
                    }
                ),
            )
        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
745
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Call ``reset()`` on the shared Chroma client (admin only).

    Errors from the client propagate to the caller. On success the endpoint returns
    no body (None).
    """
    chroma = CHROMA_CLIENT
    chroma.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
748
749
750


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Empty UPLOAD_DIR and reset the vector store (admin only).

    Best effort: failures deleting individual entries, or resetting Chroma, are
    logged rather than raised. If listing UPLOAD_DIR itself fails, that error
    propagates. Otherwise returns True.
    """
    upload_root = f"{UPLOAD_DIR}"
    for entry_name in os.listdir(upload_root):
        entry_path = os.path.join(upload_root, entry_name)
        try:
            # Files and symlinks (even links to directories) are unlinked;
            # only real directories are removed recursively.
            is_file_like = os.path.isfile(entry_path) or os.path.islink(entry_path)
            if is_file_like:
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s", entry_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True