main.py 21.4 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
9
from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
10
from fastapi.middleware.cors import CORSMiddleware
11
import os, shutil, logging, re
12
13

from pathlib import Path
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
14
from typing import List
Timothy J. Baek's avatar
Timothy J. Baek committed
15

16
from chromadb.utils.batch_utils import create_batches
Timothy J. Baek's avatar
Timothy J. Baek committed
17

Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
23
    BSHTMLLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
24
    Docx2txtLoader,
Dave Bauman's avatar
Dave Bauman committed
25
    UnstructuredEPubLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
28
    UnstructuredXMLLoader,
Marclass's avatar
Marclass committed
29
    UnstructuredRSTLoader,
Marclass's avatar
Marclass committed
30
    UnstructuredExcelLoader,
Timothy J. Baek's avatar
Timothy J. Baek committed
31
)
32
33
34
35
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pydantic import BaseModel
from typing import Optional
36
import mimetypes
37
import uuid
38
39
import json

40
import sentence_transformers
41

Timothy J. Baek's avatar
Timothy J. Baek committed
42
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
43

44
45
46
47
48
from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)
Jannik Streidl's avatar
Jannik Streidl committed
49

50
51
52
from apps.rag.utils import (
    query_embeddings_doc,
    query_embeddings_collection,
53
    generate_openai_embeddings,
54
)
Timothy J. Baek's avatar
Timothy J. Baek committed
55

56
57
58
59
60
61
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
62
from utils.utils import get_current_user, get_admin_user
63
from config import (
64
    SRC_LOG_LEVELS,
65
66
    UPLOAD_DIR,
    DOCS_DIR,
67
    RAG_EMBEDDING_ENGINE,
68
    RAG_EMBEDDING_MODEL,
69
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
Timothy J. Baek's avatar
Timothy J. Baek committed
70
71
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
72
    DEVICE_TYPE,
73
74
75
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
Timothy J. Baek's avatar
Timothy J. Baek committed
76
    RAG_TEMPLATE,
77
)
78

79
80
from constants import ERROR_MESSAGES

81
82
83
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Mutable runtime configuration. The admin endpoints below overwrite these
# attributes, so at runtime they may differ from the values imported from config.
app.state.TOP_K = 4

app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.PDF_EXTRACT_IMAGES = False

# An empty engine name selects a local sentence-transformers model, which is
# loaded eagerly at import time.
if app.state.RAG_EMBEDDING_ENGINE == "":
    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        app.state.RAG_EMBEDDING_MODEL,
        device=DEVICE_TYPE,
        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    )
else:
    # Remote engines ("ollama"/"openai") need no local model. Define the
    # attribute anyway (update_embedding_config also sets it to None for them)
    # so reading it never raises AttributeError.
    app.state.sentence_transformer_ef = None

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin requests; confirm this is intended.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


Timothy J. Baek's avatar
Timothy J. Baek committed
120
class CollectionNameForm(BaseModel):
    """Request body carrying a target collection name.

    The name defaults to "test" when the client omits it.
    """

    collection_name: Optional[str] = "test"


Timothy J. Baek's avatar
Timothy J. Baek committed
124
125
126
class StoreWebForm(CollectionNameForm):
    """Request body for /web: the page URL plus the inherited collection name."""

    url: str  # page fetched by WebBaseLoader

Timothy J. Baek's avatar
Timothy J. Baek committed
127

Timothy J. Baek's avatar
Timothy J. Baek committed
128
129
@app.get("/")
async def get_status():
    """Report liveness together with chunking, template and embedding settings.

    No user dependency is declared on this route.
    """
    state = app.state
    return dict(
        status=True,
        chunk_size=state.CHUNK_SIZE,
        chunk_overlap=state.CHUNK_OVERLAP,
        template=state.RAG_TEMPLATE,
        embedding_engine=state.RAG_EMBEDDING_ENGINE,
        embedding_model=state.RAG_EMBEDDING_MODEL,
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
140
141
@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return the active embedding engine/model and OpenAI settings (admin only)."""
    state = app.state
    # NOTE(review): the stored OpenAI API key is echoed back to the client. The
    # route is admin-gated, but confirm that exposing the secret is intended.
    openai_config = {"url": state.OPENAI_API_BASE_URL, "key": state.OPENAI_API_KEY}
    return {
        "status": True,
        "embedding_engine": state.RAG_EMBEDDING_ENGINE,
        "embedding_model": state.RAG_EMBEDDING_MODEL,
        "openai_config": openai_config,
    }


153
154
155
156
157
class OpenAIConfigForm(BaseModel):
    """OpenAI-compatible embedding endpoint settings."""

    url: str  # stored as app.state.OPENAI_API_BASE_URL
    key: str  # stored as app.state.OPENAI_API_KEY


158
class EmbeddingModelUpdateForm(BaseModel):
    """Request body for /embedding/update."""

    # Only consulted when the new engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


Timothy J. Baek's avatar
Timothy J. Baek committed
164
165
@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding engine/model used for indexing and querying (admin only).

    ``embedding_engine`` must be "" (local sentence-transformers model),
    "ollama" or "openai".

    - For the remote engines, the local model is dropped. When
      ``openai_config`` is supplied, the stored OpenAI URL/key are replaced.
    - For "", the *requested* model is loaded before any state is changed, so
      a failed load leaves the previous configuration untouched.

    Returns:
        dict with ``status``, ``embedding_engine``, ``embedding_model`` and
        ``openai_config`` reflecting the resulting state.

    Raises:
        HTTPException: 500 wrapping any error, including an unsupported engine
            name or a model that fails to load.
    """
    log.info(
        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        engine = form_data.embedding_engine

        if engine in ["ollama", "openai"]:
            app.state.RAG_EMBEDDING_ENGINE = engine
            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
            app.state.sentence_transformer_ef = None

            if form_data.openai_config is not None:
                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.OPENAI_API_KEY = form_data.openai_config.key
        elif engine == "":
            # Load the requested model, not the currently configured one, so
            # the stored model name matches the loaded weights. Nothing in
            # app.state changes until the load has succeeded.
            sentence_transformer_ef = sentence_transformers.SentenceTransformer(
                form_data.embedding_model,
                device=DEVICE_TYPE,
                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
            )
            app.state.RAG_EMBEDDING_ENGINE = engine
            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
            app.state.sentence_transformer_ef = sentence_transformer_ef
        else:
            # The query/store code in this module only has branches for "",
            # "ollama" and "openai"; storing any other value would leave them
            # without a way to embed text.
            raise ValueError(f"Unsupported embedding engine: {engine!r}")

        return {
            "status": True,
            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.OPENAI_API_BASE_URL,
                "key": app.state.OPENAI_API_KEY,
            },
        }

    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
Timothy J. Baek committed
206
207


Timothy J. Baek's avatar
Timothy J. Baek committed
208
209
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the PDF image-extraction flag and chunking parameters (admin only)."""
    chunk_settings = dict(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
    )
    return dict(
        status=True,
        pdf_extract_images=app.state.PDF_EXTRACT_IMAGES,
        chunk=chunk_settings,
    )


class ChunkParamUpdateForm(BaseModel):
    """Text-splitter settings sent to /config/update."""

    chunk_size: int  # passed to RecursiveCharacterTextSplitter as chunk_size
    chunk_overlap: int  # passed to RecursiveCharacterTextSplitter as chunk_overlap


Timothy J. Baek's avatar
Timothy J. Baek committed
225
226
227
228
229
230
231
232
233
234
class ConfigUpdateForm(BaseModel):
    """Request body for /config/update."""

    pdf_extract_images: bool  # forwarded to PyPDFLoader(extract_images=...)
    chunk: ChunkParamUpdateForm


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Update PDF image extraction and chunking parameters (admin only).

    Returns:
        dict with ``status``, ``pdf_extract_images`` and ``chunk`` reflecting
        the new state.

    Raises:
        HTTPException: 400 when ``chunk_size`` is not positive, or when
            ``chunk_overlap`` is negative or larger than ``chunk_size``.
            Nothing is changed in that case.
    """
    chunk_size = form_data.chunk.chunk_size
    chunk_overlap = form_data.chunk.chunk_overlap

    # Validate before mutating any state. These values are fed straight into
    # RecursiveCharacterTextSplitter on every upload, and that splitter rejects
    # an overlap larger than the chunk size. Storing bad values here would make
    # every later document/text upload fail.
    if chunk_size <= 0 or chunk_overlap < 0 or chunk_overlap > chunk_size:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid chunk parameters: chunk_size must be positive and chunk_overlap must be between 0 and chunk_size.",
        )

    app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
    app.state.CHUNK_SIZE = chunk_size
    app.state.CHUNK_OVERLAP = chunk_overlap

    return {
        "status": True,
        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.CHUNK_SIZE,
            "chunk_overlap": app.state.CHUNK_OVERLAP,
        },
    }
244
245


Timothy J. Baek's avatar
Timothy J. Baek committed
246
247
248
249
250
251
252
253
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template; available to any authenticated user."""
    response = {"status": True}
    response["template"] = app.state.RAG_TEMPLATE
    return response


254
255
256
257
258
259
260
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return the RAG prompt template and default top-k result count (admin only)."""
    template, top_k = app.state.RAG_TEMPLATE, app.state.TOP_K
    return {"status": True, "template": template, "k": top_k}
Timothy J. Baek's avatar
Timothy J. Baek committed
261
262


263
264
265
266
267
268
269
270
271
272
273
class QuerySettingsForm(BaseModel):
    """Request body for /query/settings/update; omitted fields reset to defaults."""

    k: Optional[int] = None  # default result count for queries
    template: Optional[str] = None  # RAG prompt template


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Set the RAG prompt template and default top-k (admin only).

    - An empty or missing template resets to the configured ``RAG_TEMPLATE``.
    - A missing or non-positive ``k`` resets to 4.

    Returns:
        dict with ``status``, the active ``template`` and the active ``k``.
        ``k`` was added so the response mirrors GET /query/settings.
    """
    app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
    # A negative k used to be stored as-is, which makes every later vector
    # query fail. Treat it like a missing value.
    app.state.TOP_K = form_data.k if form_data.k and form_data.k > 0 else 4
    return {"status": True, "template": app.state.RAG_TEMPLATE, "k": app.state.TOP_K}
275
276


277
class QueryDocForm(BaseModel):
    """Request body for /query/doc: search a single collection."""

    collection_name: str  # collection to search
    query: str  # free-text query, embedded with the configured engine
    k: Optional[int] = None  # result count; a falsy value means app.state.TOP_K
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
281
282


283
@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Embed ``form_data.query`` with the configured engine and search one collection.

    The engine is chosen by ``app.state.RAG_EMBEDDING_ENGINE``:
    - "" uses the local sentence-transformers model;
    - "ollama" and "openai" call the respective embedding helpers.

    Returns:
        Whatever ``query_embeddings_doc`` returns for the collection.

    Raises:
        HTTPException: 400 wrapping any error from embedding or querying,
            including an unsupported engine value.
    """
    try:
        engine = app.state.RAG_EMBEDDING_ENGINE

        if engine == "":
            query_embeddings = app.state.sentence_transformer_ef.encode(
                form_data.query
            ).tolist()
        elif engine == "ollama":
            query_embeddings = generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": app.state.RAG_EMBEDDING_MODEL,
                        "prompt": form_data.query,
                    }
                )
            )
        elif engine == "openai":
            query_embeddings = generate_openai_embeddings(
                model=app.state.RAG_EMBEDDING_MODEL,
                text=form_data.query,
                key=app.state.OPENAI_API_KEY,
                url=app.state.OPENAI_API_BASE_URL,
            )
        else:
            # Without this branch, an unrecognised engine left query_embeddings
            # unbound and surfaced as a confusing UnboundLocalError.
            raise ValueError(f"Unsupported embedding engine: {engine!r}")

        return query_embeddings_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            query_embeddings=query_embeddings,
            k=form_data.k if form_data.k else app.state.TOP_K,
        )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
323
324


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
325
326
327
class QueryCollectionsForm(BaseModel):
    """Request body for /query/collection: search several collections at once."""

    collection_names: List[str]  # collections to search
    query: str  # free-text query, embedded with the configured engine
    k: Optional[int] = None  # result count; a falsy value means app.state.TOP_K
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
329
330


331
@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Embed ``form_data.query`` once and search several collections with it.

    The engine is chosen by ``app.state.RAG_EMBEDDING_ENGINE``:
    - "" uses the local sentence-transformers model;
    - "ollama" and "openai" call the respective embedding helpers.

    Returns:
        Whatever ``query_embeddings_collection`` returns.

    Raises:
        HTTPException: 400 wrapping any error from embedding or querying,
            including an unsupported engine value.
    """
    try:
        engine = app.state.RAG_EMBEDDING_ENGINE

        if engine == "":
            query_embeddings = app.state.sentence_transformer_ef.encode(
                form_data.query
            ).tolist()
        elif engine == "ollama":
            query_embeddings = generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": app.state.RAG_EMBEDDING_MODEL,
                        "prompt": form_data.query,
                    }
                )
            )
        elif engine == "openai":
            query_embeddings = generate_openai_embeddings(
                model=app.state.RAG_EMBEDDING_MODEL,
                text=form_data.query,
                key=app.state.OPENAI_API_KEY,
                url=app.state.OPENAI_API_BASE_URL,
            )
        else:
            # Without this branch, an unrecognised engine left query_embeddings
            # unbound and surfaced as a confusing UnboundLocalError.
            raise ValueError(f"Unsupported embedding engine: {engine!r}")

        return query_embeddings_collection(
            collection_names=form_data.collection_names,
            query_embeddings=query_embeddings,
            k=form_data.k if form_data.k else app.state.TOP_K,
        )

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
370
371


372
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Fetch a web page and store its chunks in a vector collection.

    Any existing collection with the same name is overwritten. When no
    collection name is given, the first 63 hex characters of the URL's SHA-256
    are used.

    Returns:
        dict with ``status``, ``collection_name`` and ``filename`` (the URL).

    Raises:
        HTTPException: 400 wrapping any fetch, split or store error.
    """
    try:
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        # collection_name is Optional on the form, so an explicit null must
        # fall back just like "" does; otherwise None would reach Chroma.
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        # NOTE(review): the store result is not checked here, so a store that
        # reports failure without raising still returns status True.
        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


397
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in ``collection_name``.

    Args:
        data: Documents as returned by a LangChain loader's ``load()``.
        collection_name: Target Chroma collection.
        overwrite: If True, an existing collection with that name is replaced.

    Returns:
        The boolean result of ``store_docs_in_vector_db``.

    Raises:
        ValueError: ``ERROR_MESSAGES.EMPTY_CONTENT`` when splitting yields no chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)
    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    # Return the bool itself, as annotated. The previous ``(result, None)``
    # tuple was always truthy, so callers testing ``if result:`` treated failed
    # stores as successes.
    return store_docs_in_vector_db(docs, collection_name, overwrite)
412
413
414


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Chunk one raw text string and store the chunks in ``collection_name``.

    Every chunk carries the same ``metadata`` dict. Chunk size and overlap
    come from ``app.state``.

    Returns:
        The result of ``store_docs_in_vector_db``.
    """
    state = app.state
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=state.CHUNK_SIZE,
        chunk_overlap=state.CHUNK_OVERLAP,
        add_start_index=True,
    )
    chunks = splitter.create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(chunks, collection_name, overwrite)


Timothy J. Baek's avatar
Timothy J. Baek committed
426
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed documents and add them to a newly created Chroma collection.

    Args:
        docs: Documents exposing ``page_content`` and ``metadata``.
        collection_name: Name of the Chroma collection to create.
        overwrite: If True, delete an existing collection of that name first.

    Returns:
        True when the documents were stored. Also True when creating the
        collection raised ``UniqueConstraintError``, which is treated as
        "already stored". False on any other error. Errors are logged, not
        raised.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    # Newlines are flattened to spaces before embedding and storing.
    texts = [doc.page_content.replace("\n", " ") for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if collection_name == existing.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)
    except Exception as e:
        log.exception(e)
        if e.__class__.__name__ == "UniqueConstraintError":
            return True
        return False

    try:
        engine = app.state.RAG_EMBEDDING_ENGINE

        if engine == "":
            embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
        elif engine == "ollama":
            embeddings = [
                generate_ollama_embeddings(
                    GenerateEmbeddingsForm(
                        **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
                    )
                )
                for text in texts
            ]
        elif engine == "openai":
            embeddings = [
                generate_openai_embeddings(
                    model=app.state.RAG_EMBEDDING_MODEL,
                    text=text,
                    key=app.state.OPENAI_API_KEY,
                    url=app.state.OPENAI_API_BASE_URL,
                )
                for text in texts
            ]
        else:
            raise ValueError(f"Unsupported embedding engine: {engine!r}")

        for batch in create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid1()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        ):
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        # The collection was created above, but embedding or adding failed.
        # Remove it so that a later non-overwrite call does not hit
        # UniqueConstraintError and report success for an empty or partially
        # populated collection.
        try:
            CHROMA_CLIENT.delete_collection(name=collection_name)
        except Exception as cleanup_error:
            log.exception(cleanup_error)
        return False


483
484
def get_loader(filename: str, file_content_type: str, file_path: str):
    """Choose a LangChain document loader for a file.

    Args:
        filename: File name. Its lower-cased text after the last "." drives
            most of the dispatch; a name without "." is used whole (e.g.
            "Dockerfile" -> "dockerfile").
        file_content_type: MIME type reported for the file; may be None.
        file_path: Path on disk that the loader will read.

    Returns:
        ``(loader, known_type)``. ``known_type`` is False only when no rule
        matched and the file falls back to plain-text loading.
    """
    file_ext = filename.split(".")[-1].lower()
    known_type = True

    # Extensions that are plain text (source code, config, logs) and can be
    # read with TextLoader.
    known_source_ext = {
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    }

    if file_ext == "pdf":
        loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext in ["htm", "html"]:
        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif file_ext == "epub" or file_content_type == "application/epub+zip":
        # Match on the extension too, as the Word/Excel branches do, so EPUBs
        # uploaded with a generic or missing MIME type are not read as text.
        loader = UnstructuredEPubLoader(file_path)
    elif (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or file_ext in ["doc", "docx"]
    ):
        loader = Docx2txtLoader(file_path)
    elif file_content_type in [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ] or file_ext in ["xls", "xlsx"]:
        loader = UnstructuredExcelLoader(file_path)
    elif file_ext in known_source_ext or (
        file_content_type and "text/" in file_content_type
    ):
        loader = TextLoader(file_path, autodetect_encoding=True)
    else:
        # Unknown type: still try plain text, but report it to the caller.
        loader = TextLoader(file_path, autodetect_encoding=True)
        known_type = False

    return loader, known_type


568
@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file to UPLOAD_DIR, load it, and store its chunks.

    Args:
        collection_name: Target collection. When omitted, the first 63 hex
            characters of the file's ``calculate_sha256`` digest are used.
        file: The uploaded file.
        user: Authenticated user (dependency only).

    Returns:
        dict with ``status``, ``collection_name``, ``filename`` and ``known_type``.

    Raises:
        HTTPException: 400 if saving, loading or storing raises (with a
            dedicated message when pandoc is missing); 500 if the store reports
            failure without raising.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        # basename() strips directory components from the client-supplied name.
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()
        result = store_data_in_vector_db(data, collection_name)
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    if not result:
        # Previously this path fell through and returned null with HTTP 200.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "filename": filename,
        "known_type": known_type,
    }
623
624


625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
class TextRAGForm(BaseModel):
    """Request body for /text: raw content stored as a named document."""

    name: str  # saved in each chunk's metadata as "name"
    content: str  # text that is chunked and embedded
    collection_name: Optional[str] = None  # derived from content when omitted


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Chunk and store raw text, tagging chunks with its name and creator.

    When no collection name is given, the first 63 hex characters of the
    content's SHA-256 are used.

    Returns:
        dict with ``status`` and ``collection_name``.

    Raises:
        HTTPException: 500 when the store reports failure.
    """
    collection_name = form_data.collection_name
    if not collection_name:
        # Truncate to 63 characters like store_doc/store_web do. A SHA-256 hex
        # digest is 64 characters, which exceeds Chroma's collection-name
        # limit, so the untruncated name made every such request fail.
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}

    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=ERROR_MESSAGES.DEFAULT(),
    )


656
657
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under DOCS_DIR and register new ones as documents.

    For each regular file whose name does not start with ".":
    - the collection name is the first 63 hex characters of the file's
      ``calculate_sha256`` digest;
    - the file is loaded with ``get_loader`` and stored with
      ``store_data_in_vector_db``;
    - when storing reports success and no document with the sanitized filename
      exists, a document record is created for the requesting admin, with the
      folders from ``extract_folders_after_data_docs`` saved as tags.

    Returns:
        True. Per-file failures are logged and the scan continues.
    """
    # rglob("*") already recurses into every subdirectory.
    for path in Path(DOCS_DIR).rglob("*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            file_content_type = mimetypes.guess_type(path)

            # The context manager closes the handle even if hashing raises.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(
                filename, file_content_type[0], str(path)
            )
            data = loader.load()

            result = store_data_in_vector_db(data, collection_name)
            if not result:
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is None:
                content = (
                    json.dumps({"tags": [{"name": name} for name in tags]})
                    if len(tags)
                    else "{}"
                )
                Documents.insert_new_doc(
                    user.id,
                    DocumentForm(
                        **{
                            "name": sanitized_filename,
                            "title": filename,
                            "collection_name": collection_name,
                            "filename": filename,
                            "content": content,
                        }
                    ),
                )
        except Exception as e:
            log.exception(e)

    return True


Timothy J. Baek's avatar
Timothy J. Baek committed
717
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Wipe the Chroma vector store (admin only); errors are not caught here."""
    client = CHROMA_CLIENT
    client.reset()
Timothy J. Baek's avatar
Timothy J. Baek committed
720
721
722


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete every entry inside UPLOAD_DIR, then reset the Chroma store (admin only).

    A failure to delete an entry is logged and skipped. A failing Chroma reset
    is logged as well.

    Returns:
        Always True.
    """

    def _remove(entry_path):
        # islink() is tested before isdir(), so a symlink to a directory is
        # unlinked instead of having its target tree removed.
        if os.path.isfile(entry_path) or os.path.islink(entry_path):
            os.unlink(entry_path)
        elif os.path.isdir(entry_path):
            shutil.rmtree(entry_path)

    upload_root = f"{UPLOAD_DIR}"
    for entry_name in os.listdir(upload_root):
        entry_path = os.path.join(upload_root, entry_name)
        try:
            _remove(entry_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s" % (entry_path, e))

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True