Commit 784b369c authored by Timothy J. Baek's avatar Timothy J. Baek
Browse files

feat: chromadb vector store api

parent b2c9f6df
......@@ -5,4 +5,5 @@ uploads
.ipynb_checkpoints
*.db
_test
Pipfile
\ No newline at end of file
Pipfile
data/*
\ No newline at end of file
from fastapi import FastAPI, Request, Depends, HTTPException
from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH
from chromadb.utils import embedding_functions
from langchain.document_loaders import WebBaseLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from pydantic import BaseModel
from typing import Optional
import uuid
from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES
EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=EMBED_MODEL
)
app = FastAPI()
......@@ -18,6 +34,84 @@ app.add_middleware(
)
class StoreWebForm(BaseModel):
url: str
collection_name: Optional[str] = "test"
def store_data_in_vector_db(data, collection_name):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
)
docs = text_splitter.split_documents(data)
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
collection = CHROMA_CLIENT.create_collection(
name=collection_name, embedding_function=EMBEDDING_FUNC
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
@app.get("/")
async def get_status():
return {"status": True}
@app.get("/query/{collection_name}")
def query_collection(collection_name: str, query: str, k: Optional[int] = 4):
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
)
result = collection.query(query_texts=[query], n_results=k)
return result
@app.post("/web")
def store_web(form_data: StoreWebForm):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = WebBaseLoader(form_data.url)
data = loader.load()
store_data_in_vector_db(data, form_data.collection_name)
return {"status": True}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@app.post("/doc")
def store_doc(file: UploadFile = File(...)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
print(file)
file.filename = f"{uuid.uuid4()}-{file.filename}"
contents = file.file.read()
with open(f"./data/{file.filename}", "wb") as f:
f.write(contents)
f.close()
# loader = WebBaseLoader(form_data.url)
# data = loader.load()
# store_data_in_vector_db(data, form_data.collection_name)
return {"status": True}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
def reset_vector_db():
CHROMA_CLIENT.reset()
return {"status": True}
from dotenv import load_dotenv, find_dotenv
from constants import ERROR_MESSAGES
import os
import chromadb
from secrets import token_bytes
from base64 import b64encode
import os
from constants import ERROR_MESSAGES
load_dotenv(find_dotenv("../.env"))
......@@ -19,8 +19,9 @@ ENV = os.environ.get("ENV", "dev")
# OLLAMA_API_BASE_URL
####################################
OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL",
"http://localhost:11434/api")
OLLAMA_API_BASE_URL = os.environ.get(
"OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)
if ENV == "prod":
if OLLAMA_API_BASE_URL == "/ollama/api":
......@@ -56,3 +57,13 @@ WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t")
if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG
####################################
CHROMA_DATA_PATH = "./data/vector_db"
EMBED_MODEL = "all-MiniLM-L6-v2"
CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH)
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100
......@@ -6,7 +6,6 @@ class MESSAGES(str, Enum):
class ERROR_MESSAGES(str, Enum):
def __str__(self) -> str:
return super().__str__()
......@@ -30,7 +29,10 @@ class ERROR_MESSAGES(str, Enum):
UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = (
"The requested action has been restricted as a security measure.")
"The requested action has been restricted as a security measure."
)
FILE_NOT_SENT = "FILE_NOT_SENT"
NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment