Commit 1846c1e8 authored by Jannik Streidl's avatar Jannik Streidl
Browse files

choose embedding model when using docker

parent 4c3edd03
...@@ -30,10 +30,16 @@ ENV WEBUI_SECRET_KEY "" ...@@ -30,10 +30,16 @@ ENV WEBUI_SECRET_KEY ""
ENV SCARF_NO_ANALYTICS true ENV SCARF_NO_ANALYTICS true
ENV DO_NOT_TRACK true ENV DO_NOT_TRACK true
#Whisper TTS Settings # whisper TTS Settings
ENV WHISPER_MODEL="base" ENV WHISPER_MODEL="base"
ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better persormance and multilangauge support use "intfloat/multilingual-e5-large"
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2"
WORKDIR /app/backend WORKDIR /app/backend
# install python dependencies # install python dependencies
...@@ -48,7 +54,9 @@ RUN apt-get update \ ...@@ -48,7 +54,9 @@ RUN apt-get update \
&& apt-get install -y pandoc netcat-openbsd \ && apt-get install -y pandoc netcat-openbsd \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')" # preload embedding model
RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])"
# preload tts model
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
......
from fastapi import ( from fastapi import (
FastAPI, FastAPI,
Request,
Depends, Depends,
HTTPException, HTTPException,
status, status,
...@@ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware ...@@ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
import os, shutil import os, shutil
from typing import List from typing import List
# from chromadb.utils import embedding_functions from chromadb.utils import embedding_functions
from langchain_community.document_loaders import ( from langchain_community.document_loaders import (
WebBaseLoader, WebBaseLoader,
...@@ -28,24 +27,19 @@ from langchain_community.document_loaders import ( ...@@ -28,24 +27,19 @@ from langchain_community.document_loaders import (
UnstructuredExcelLoader, UnstructuredExcelLoader,
) )
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
import uuid import uuid
import time
from utils.misc import calculate_sha256, calculate_sha256_string from utils.misc import calculate_sha256, calculate_sha256_string
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
# model_name=EMBED_MODEL
# )
app = FastAPI() app = FastAPI()
...@@ -78,11 +72,17 @@ def store_data_in_vector_db(data, collection_name) -> bool: ...@@ -78,11 +72,17 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
try: try:
collection = CHROMA_CLIENT.create_collection(name=collection_name) if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef)
else:
# for local development use the default model
collection = CHROMA_CLIENT.create_collection(name=collection_name)
collection.add( collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
) )
return True return True
except Exception as e: except Exception as e:
print(e) print(e)
...@@ -109,9 +109,17 @@ def query_doc( ...@@ -109,9 +109,17 @@ def query_doc(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
collection = CHROMA_CLIENT.get_collection( if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
name=form_data.collection_name, # if you use docker use the model from the environment variable
) collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
embedding_function=sentence_transformer_ef
)
else:
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
)
result = collection.query(query_texts=[form_data.query], n_results=form_data.k) result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result return result
except Exception as e: except Exception as e:
...@@ -182,9 +190,18 @@ def query_collection( ...@@ -182,9 +190,18 @@ def query_collection(
for collection_name in form_data.collection_names: for collection_name in form_data.collection_names:
try: try:
collection = CHROMA_CLIENT.get_collection( if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
name=collection_name, # if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
embedding_function=sentence_transformer_ef
)
else:
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
) )
result = collection.query( result = collection.query(
query_texts=[form_data.query], n_results=form_data.k query_texts=[form_data.query], n_results=form_data.k
) )
......
...@@ -128,7 +128,8 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": ...@@ -128,7 +128,8 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
#################################### ####################################
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
EMBED_MODEL = "all-MiniLM-L6-v2" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL")
CHROMA_CLIENT = chromadb.PersistentClient( CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH, path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False), settings=Settings(allow_reset=True, anonymized_telemetry=False),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment