main.py 22.4 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils import embedding_functions
17
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
18

Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
21
22
23
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
24
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
25
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
26
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
27
28
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
29
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
31
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
32
)
33
34
35
36
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pydantic import BaseModel
from typing import Optional
37
import mimetypes
38
import uuid
39
40
import json

41

Timothy J. Baek's avatar
Timothy J. Baek committed
42
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
43

44
45
46
47
48
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
49

50
51
52
53
54
55
from apps.rag.utils import (
    query_doc,
    query_embeddings_doc,
    query_collection,
    query_embeddings_collection,
    get_embedding_model_path,
56
    generate_openai_embeddings,
57
)
Timothy J. Baek's avatar
Timothy J. Baek committed
58

59
60
61
62
63
64
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
65
from utils.utils import get_current_user, get_admin_user
66
from config import (
67
    SRC_LOG_LEVELS,
68
69
    UPLOAD_DIR,
    DOCS_DIR,
70
    RAG_EMBEDDING_ENGINE,
71
    RAG_EMBEDDING_MODEL,
72
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
73
    DEVICE_TYPE,
74
75
76
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
77
    RAG_TEMPLATE,
78
)
79

80
81
from constants import ERROR_MESSAGES

82
83
84
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Number of results requested per query when a request does not supply `k`.
app.state.TOP_K = 4

# Text-splitter parameters, seeded from config.
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding backend, model and prompt template, seeded from config.
# An empty engine string selects the local sentence-transformer function below.
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

# Connection settings passed to generate_openai_embeddings.
app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
app.state.RAG_OPENAI_API_KEY = ""

# Forwarded to PyPDFLoader(extract_images=...).
app.state.PDF_EXTRACT_IMAGES = False

# Built eagerly at import time from the configured model.
_initial_model_path = get_embedding_model_path(
    app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
)
app.state.sentence_transformer_ef = (
    embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=_initial_model_path,
        device=DEVICE_TYPE,
    )
)

# CORS: every origin, method and header is accepted, with credentials enabled.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
124
class CollectionNameForm(BaseModel):
    # Base request body naming the vector-store collection to operate on.
    # The field is optional and falls back to "test" when omitted.
    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
128
129
130
class StoreWebForm(CollectionNameForm):
    # Request body for POST /web: the URL handed to WebBaseLoader, in addition
    # to the inherited optional collection_name.
    url: str

Timothy J. Baek's avatar
Timothy J. Baek committed
131

Timothy J. Baek's avatar
Timothy J. Baek committed
132
133
@app.get("/")
async def get_status():
    """Return a status flag together with the current chunking, template and embedding settings."""
    state = app.state
    return dict(
        status=True,
        chunk_size=state.CHUNK_SIZE,
        chunk_overlap=state.CHUNK_OVERLAP,
        template=state.RAG_TEMPLATE,
        embedding_engine=state.RAG_EMBEDDING_ENGINE,
        embedding_model=state.RAG_EMBEDDING_MODEL,
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
144
145
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the active embedding engine, model and OpenAI connection settings (admin only)."""
    state = app.state
    openai_config = {
        "url": state.RAG_OPENAI_API_BASE_URL,
        "key": state.RAG_OPENAI_API_KEY,
    }
    return {
        "status": True,
        "embedding_engine": state.RAG_EMBEDDING_ENGINE,
        "embedding_model": state.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


157
158
159
160
161
class OpenAIConfigForm(BaseModel):
    # OpenAI connection settings; stored on app.state as
    # RAG_OPENAI_API_BASE_URL (url) and RAG_OPENAI_API_KEY (key).
    url: str
    key: str


162
class EmbeddingModelUpdateForm(BaseModel):
    # Request body for POST /embedding/update.
    # openai_config is optional; it is only applied when the selected engine
    # is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    # Engine selector: "ollama", "openai", or anything else for the local
    # sentence-transformer embedding function.
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
168
169
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine and model (admin only).

    For the "ollama" and "openai" engines the local sentence-transformer
    function is dropped and, when supplied, the OpenAI URL/key are stored.
    For any other engine value a SentenceTransformerEmbeddingFunction is
    built for the requested model.

    The new embedding function is constructed *before* any application state
    is modified, so a model that fails to load leaves the previous
    engine/model/function combination intact.

    Raises:
        HTTPException: 500 with the default error message when the update fails.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        engine = form_data.embedding_engine
        uses_remote_engine = engine in ["ollama", "openai"]

        if uses_remote_engine:
            sentence_transformer_ef = None
        else:
            # May raise (e.g. unknown model); nothing has been committed yet.
            sentence_transformer_ef = (
                embedding_functions.SentenceTransformerEmbeddingFunction(
                    model_name=get_embedding_model_path(
                        form_data.embedding_model, True
                    ),
                    device=DEVICE_TYPE,
                )
            )

        # Commit the new configuration in one go.
        app.state.RAG_EMBEDDING_ENGINE = engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
        app.state.sentence_transformer_ef = sentence_transformer_ef

        if uses_remote_engine and form_data.openai_config is not None:
            app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
            app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key

        return {
            "status": True,
            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.RAG_OPENAI_API_BASE_URL,
                "key": app.state.RAG_OPENAI_API_KEY,
            },
        }

    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
213
214


Timothy J. Baek's avatar
Timothy J. Baek committed
215
216
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF image-extraction flag and the chunking parameters (admin only)."""
    state = app.state
    chunk_settings = {
        "chunk_size": state.CHUNK_SIZE,
        "chunk_overlap": state.CHUNK_OVERLAP,
    }
    return {
        "status": True,
        "pdf_extract_images": state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
    }


class ChunkParamUpdateForm(BaseModel):
    # New text-splitter parameters; copied to app.state.CHUNK_SIZE and
    # app.state.CHUNK_OVERLAP by the /config/update handler.
    chunk_size: int
    chunk_overlap: int


Timothy J. Baek's avatar
Timothy J. Baek committed
232
233
234
235
236
237
238
239
240
241
class ConfigUpdateForm(BaseModel):
    # Request body for POST /config/update.
    # pdf_extract_images is stored as app.state.PDF_EXTRACT_IMAGES.
    pdf_extract_images: bool
    # Nested chunking parameters (size and overlap).
    chunk: ChunkParamUpdateForm


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Store the PDF image-extraction flag and chunking parameters, then echo them back (admin only)."""
    state = app.state
    chunk = form_data.chunk

    state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
    state.CHUNK_SIZE = chunk.chunk_size
    state.CHUNK_OVERLAP = chunk.chunk_overlap

    chunk_settings = {
        "chunk_size": state.CHUNK_SIZE,
        "chunk_overlap": state.CHUNK_OVERLAP,
    }
    return {
        "status": True,
        "pdf_extract_images": state.PDF_EXTRACT_IMAGES,
        "chunk": chunk_settings,
    }
251
252


Timothy J. Baek's avatar
Timothy J. Baek committed
253
254
255
256
257
258
259
260
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the current RAG prompt template to any authenticated user."""
    template = app.state.RAG_TEMPLATE
    return {"status": True, "template": template}


261
262
263
264
265
266
267
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the RAG prompt template and the default result count `k` (admin only)."""
    state = app.state
    return dict(status=True, template=state.RAG_TEMPLATE, k=state.TOP_K)
Timothy J. Baek's avatar
Timothy J. Baek committed
268
269


270
271
272
273
274
275
276
277
278
279
280
class QuerySettingsForm(BaseModel):
    # Request body for POST /query/settings/update. Both fields are optional;
    # a missing or falsy value makes the handler fall back to its default
    # (4 for k, the configured RAG_TEMPLATE for template).
    k: Optional[int] = None
    template: Optional[str] = None


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Update the RAG prompt template and default result count (admin only).

    A missing or empty template resets to the configured RAG_TEMPLATE; a
    missing or zero `k` resets to 4.

    Returns:
        The stored template and, mirroring GET /query/settings, the stored `k`
        so the caller can see which values actually took effect.
    """
    app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
    app.state.TOP_K = form_data.k if form_data.k else 4
    return {
        "status": True,
        "template": app.state.RAG_TEMPLATE,
        "k": app.state.TOP_K,
    }
282
283


284
class QueryDocForm(BaseModel):
    # Request body for POST /query/doc: one collection, the query text, and an
    # optional result count (the handler falls back to app.state.TOP_K).
    collection_name: str
    query: str
    k: Optional[int] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
288
289


290
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query a single collection using the configured embedding engine.

    With an empty engine string the local sentence-transformer function embeds
    the query; with "ollama" or "openai" the query embedding is generated
    remotely and passed to the collection directly.

    Raises:
        HTTPException: 400 with the default error message on any failure,
            including an unrecognised embedding engine.
    """
    try:
        engine = app.state.RAG_EMBEDDING_ENGINE
        k = form_data.k if form_data.k else app.state.TOP_K

        if engine == "":
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                k=k,
                embedding_function=app.state.sentence_transformer_ef,
            )

        if engine == "ollama":
            query_embeddings = generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": app.state.RAG_EMBEDDING_MODEL,
                        "prompt": form_data.query,
                    }
                )
            )
        elif engine == "openai":
            query_embeddings = generate_openai_embeddings(
                model=app.state.RAG_EMBEDDING_MODEL,
                text=form_data.query,
                key=app.state.RAG_OPENAI_API_KEY,
                url=app.state.RAG_OPENAI_API_BASE_URL,
            )
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(f"Unsupported embedding engine: {engine}")

        return query_embeddings_doc(
            collection_name=form_data.collection_name,
            query_embeddings=query_embeddings,
            k=k,
        )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
334
335


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
336
337
338
class QueryCollectionsForm(BaseModel):
    # Request body for POST /query/collection: several collections, the query
    # text, and an optional result count (the handler falls back to
    # app.state.TOP_K).
    collection_names: List[str]
    query: str
    k: Optional[int] = None
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
340
341


342
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections using the configured embedding engine.

    With an empty engine string the local sentence-transformer function embeds
    the query; with "ollama" or "openai" the query embedding is generated
    remotely and passed to the collections directly.

    Raises:
        HTTPException: 400 with the default error message on any failure,
            including an unrecognised embedding engine.
    """
    try:
        engine = app.state.RAG_EMBEDDING_ENGINE
        k = form_data.k if form_data.k else app.state.TOP_K

        if engine == "":
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                k=k,
                embedding_function=app.state.sentence_transformer_ef,
            )

        if engine == "ollama":
            query_embeddings = generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": app.state.RAG_EMBEDDING_MODEL,
                        "prompt": form_data.query,
                    }
                )
            )
        elif engine == "openai":
            query_embeddings = generate_openai_embeddings(
                model=app.state.RAG_EMBEDDING_MODEL,
                text=form_data.query,
                key=app.state.RAG_OPENAI_API_KEY,
                url=app.state.RAG_OPENAI_API_BASE_URL,
            )
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(f"Unsupported embedding engine: {engine}")

        return query_embeddings_collection(
            collection_names=form_data.collection_names,
            query_embeddings=query_embeddings,
            k=k,
        )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
386
387


388
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Load a web page and store its chunks, replacing any existing collection of the same name.

    When no collection name is given (empty string or None) the name is
    derived from the SHA-256 of the URL, truncated to 63 characters.

    Raises:
        HTTPException: 400 with the default error message when loading or
            storing fails.
    """
    # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        # The field is Optional, so treat None the same as an empty string.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


413
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in the vector database.

    Args:
        data: Documents as returned by a loader's ``load()``.
        collection_name: Name of the target collection.
        overwrite: Forwarded to store_docs_in_vector_db; when true an existing
            collection of the same name is deleted first.

    Returns:
        The boolean outcome of store_docs_in_vector_db. (A plain bool, as the
        annotation promises, so callers' truthiness checks can detect failure.)

    Raises:
        ValueError: When splitting yields no chunks (empty content).
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    return store_docs_in_vector_db(docs, collection_name, overwrite)
428
429
430


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk a raw text (attaching `metadata` to it) and store the chunks.

    Returns whatever store_docs_in_vector_db reports for the resulting chunks.
    """
    splitter_options = {
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
        "add_start_index": True,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
442
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Create a collection and add the given document chunks to it.

    With an empty engine string the collection is created with the local
    sentence-transformer function and embeds the texts itself. With "ollama"
    or "openai" one embedding per text is generated up front and stored
    together with the text, so query results still carry the document content.

    Args:
        docs: Chunks exposing ``page_content`` and ``metadata``.
        collection_name: Name of the collection to create.
        overwrite: When true, delete an existing collection of this name first.

    Returns:
        True on success, or when the failure is a ``UniqueConstraintError``
        (the collection already exists); False on any other error.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        engine = app.state.RAG_EMBEDDING_ENGINE
        # Reject unknown engines before touching any existing collection.
        if engine not in ("", "ollama", "openai"):
            raise ValueError(f"Unsupported embedding engine: {engine}")

        if overwrite:
            for collection in CHROMA_CLIENT.list_collections():
                if collection_name == collection.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        ids = [str(uuid.uuid1()) for _ in texts]

        if engine == "":
            collection = CHROMA_CLIENT.create_collection(
                name=collection_name,
                embedding_function=app.state.sentence_transformer_ef,
            )
            batches = create_batches(
                api=CHROMA_CLIENT,
                ids=ids,
                metadatas=metadatas,
                documents=texts,
            )
        else:
            collection = CHROMA_CLIENT.create_collection(name=collection_name)

            if engine == "ollama":
                embeddings = [
                    generate_ollama_embeddings(
                        GenerateEmbeddingsForm(
                            **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
                        )
                    )
                    for text in texts
                ]
            else:  # "openai"
                embeddings = [
                    generate_openai_embeddings(
                        model=app.state.RAG_EMBEDDING_MODEL,
                        text=text,
                        key=app.state.RAG_OPENAI_API_KEY,
                        url=app.state.RAG_OPENAI_API_BASE_URL,
                    )
                    for text in texts
                ]

            # Store the texts alongside the precomputed embeddings; without
            # them the collection would hold vectors but no document content.
            batches = create_batches(
                api=CHROMA_CLIENT,
                ids=ids,
                metadatas=metadatas,
                embeddings=embeddings,
                documents=texts,
            )

        for batch in batches:
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # Matched by class name; an already-existing collection counts as stored.
        if e.__class__.__name__ == "UniqueConstraintError":
            return True

        return False


510
511
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a document loader for a file from its extension and content type.

    Returns a ``(loader, known_type)`` pair. ``known_type`` is False only when
    nothing matched and the plain-text loader is used as a last resort.
    """
    file_ext = filename.rsplit(".", 1)[-1].lower()

    # Extensions treated as plain-text source/config files.
    known_source_ext = [
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    ]

    docx_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    excel_mimes = [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ]

    known_type = True

    if file_ext == "pdf":
        loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext in ["htm", "html"]:
        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif file_content_type == "application/epub+zip":
        loader = UnstructuredEPubLoader(file_path)
    elif file_content_type == docx_mime or file_ext in ["doc", "docx"]:
        loader = Docx2txtLoader(file_path)
    elif file_content_type in excel_mimes or file_ext in ["xls", "xlsx"]:
        loader = UnstructuredExcelLoader(file_path)
    else:
        # Both remaining cases use the text loader; they differ only in
        # whether the type counts as recognised.
        is_text_mime = bool(file_content_type and "text/" in file_content_type)
        known_type = file_ext in known_source_ext or is_text_mime
        loader = TextLoader(file_path, autodetect_encoding=True)

    return loader, known_type


595
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under UPLOAD_DIR, load it and store its chunks.

    When no collection name is supplied it is derived from the SHA-256 of the
    saved file, truncated to 63 characters.

    Returns:
        A dict with the collection name, the stored filename and whether the
        file type was recognised by get_loader.

    Raises:
        HTTPException: 400 for any failure; a dedicated message is used when
            the error text indicates that pandoc is missing.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        unsanitized_filename = file.filename
        # Keep only the final path component of the client-supplied name.
        filename = os.path.basename(unsanitized_filename)

        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)
        except Exception as e:
            # `detail` must be JSON-serialisable, so pass the message text.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=str(e),
            ) from e

        if not result:
            # Report a failed store instead of returning an empty response.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": filename,
            "known_type": known_type,
        }
    except Exception as e:
        # NOTE(review): this also catches the HTTPExceptions raised above and
        # re-maps them to 400, as the original code did — confirm whether the
        # 500 status was meant to reach the client.
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
650
651


652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
class TextRAGForm(BaseModel):
    # Request body for POST /text.
    # name is recorded in the stored chunks' metadata; content is the raw text
    # to be chunked and stored.
    name: str
    content: str
    # Optional target collection; when omitted the handler derives a name
    # from a hash of the content.
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Chunk and store a raw text, tagging it with its name and the requesting user's id.

    When no collection name is supplied it is derived from the SHA-256 of the
    content, truncated to 63 characters like the other endpoints in this
    module (a full 64-character hex digest is one character longer than that).

    Raises:
        HTTPException: 500 with the default error message when storing fails.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


683
684
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under DOCS_DIR into the vector database (admin only).

    Each file is stored in a collection named after the first 63 characters of
    its SHA-256. When storing succeeds and no document with the sanitized
    filename exists yet, a document record is inserted, tagged with the
    values returned by extract_folders_after_data_docs.

    Failures are logged per file and do not stop the scan.

    Returns:
        Always True.
    """
    for path in Path(DOCS_DIR).rglob("./**/*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # Context manager so the handle is released even if hashing fails.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            try:
                result = store_data_in_vector_db(data, collection_name)

                if result:
                    sanitized_filename = sanitize_filename(filename)
                    doc = Documents.get_doc_by_name(sanitized_filename)

                    if doc is None:
                        content = (
                            json.dumps({"tags": [{"name": name} for name in tags]})
                            if tags
                            else "{}"
                        )
                        doc = Documents.insert_new_doc(
                            user.id,
                            DocumentForm(
                                **{
                                    "name": sanitized_filename,
                                    "title": filename,
                                    "collection_name": collection_name,
                                    "filename": filename,
                                    "content": content,
                                }
                            ),
                        )
            except Exception as e:
                # Best effort: keep scanning the remaining files.
                log.exception(e)

        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
744
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the Chroma client (admin only)."""
    CHROMA_CLIENT.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
747
748
749


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything under UPLOAD_DIR and reset the Chroma client (admin only).

    Deletion and reset errors are logged rather than raised; the handler
    always returns True.
    """
    upload_root = f"{UPLOAD_DIR}"

    for entry in os.listdir(upload_root):
        target = os.path.join(upload_root, entry)
        try:
            # Files and symlinks are unlinked; directories are removed recursively.
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            log.error("Failed to delete %s. Reason: %s" % (target, err))

    try:
        CHROMA_CLIENT.reset()
    except Exception as err:
        log.exception(err)

    return True