"vscode:/vscode.git/clone" did not exist on "9ad23fad27bbb5827b66d5382a0098dd27071db9"
Unverified Commit 85175470 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #1555 from open-webui/dev

0.1.119
parents 0399a69b 375056f8
...@@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file. ...@@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.119] - 2024-04-16
### Added
- **🌟 Enhanced RAG Embedding Support**: Ollama, and OpenAI models can now be used for RAG embedding model.
- **🔄 Seamless Integration**: Copy 'ollama run <model name>' directly from Ollama page to easily select and pull models.
- **🏷️ Tagging Feature**: Add tags to chats directly via the sidebar chat menu.
- **📱 Mobile Accessibility**: Swipe left and right on mobile to effortlessly open and close the sidebar.
- **🔍 Improved Navigation**: Admin panel now supports pagination for user list.
- **🌍 Additional Language Support**: Added Polish language support.
### Fixed
- **🌍 Language Enhancements**: Vietnamese and Spanish translations have been improved.
- **🔧 Helm Fixes**: Resolved issues with Helm trailing slash and manifest.json.
### Changed
- **🐳 Docker Optimization**: Updated docker image build process to utilize 'uv' for significantly faster builds compared to 'pip3'.
## [0.1.118] - 2024-04-10 ## [0.1.118] - 2024-04-10
### Added ### Added
......
...@@ -93,15 +93,16 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \ ...@@ -93,15 +93,16 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \
# install python dependencies # install python dependencies
COPY ./backend/requirements.txt ./requirements.txt COPY ./backend/requirements.txt ./requirements.txt
RUN if [ "$USE_CUDA" = "true" ]; then \ RUN pip3 install uv && \
if [ "$USE_CUDA" = "true" ]; then \
# If you use CUDA the whisper and embedding model will be downloaded on first use # If you use CUDA the whisper and embedding model will be downloaded on first use
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
pip3 install -r requirements.txt --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \ python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
else \ else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
pip3 install -r requirements.txt --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \ python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
fi fi
......
...@@ -28,6 +28,7 @@ from config import ( ...@@ -28,6 +28,7 @@ from config import (
UPLOAD_DIR, UPLOAD_DIR,
WHISPER_MODEL, WHISPER_MODEL,
WHISPER_MODEL_DIR, WHISPER_MODEL_DIR,
WHISPER_MODEL_AUTO_UPDATE,
DEVICE_TYPE, DEVICE_TYPE,
) )
...@@ -69,12 +70,24 @@ def transcribe( ...@@ -69,12 +70,24 @@ def transcribe(
f.write(contents) f.write(contents)
f.close() f.close()
model = WhisperModel( whisper_kwargs = {
WHISPER_MODEL, "model_size_or_path": WHISPER_MODEL,
device=whisper_device_type, "device": whisper_device_type,
compute_type="int8", "compute_type": "int8",
download_root=WHISPER_MODEL_DIR, "download_root": WHISPER_MODEL_DIR,
) "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
}
log.debug(f"whisper_kwargs: {whisper_kwargs}")
try:
model = WhisperModel(**whisper_kwargs)
except:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
whisper_kwargs["local_files_only"] = False
model = WhisperModel(**whisper_kwargs)
segments, info = model.transcribe(file_path, beam_size=5) segments, info = model.transcribe(file_path, beam_size=5)
log.info( log.info(
......
...@@ -29,7 +29,13 @@ import base64 ...@@ -29,7 +29,13 @@ import base64
import json import json
import logging import logging
from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL from config import (
SRC_LOG_LEVELS,
CACHE_DIR,
ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL,
COMFYUI_BASE_URL,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -48,7 +54,7 @@ app.add_middleware( ...@@ -48,7 +54,7 @@ app.add_middleware(
) )
app.state.ENGINE = "" app.state.ENGINE = ""
app.state.ENABLED = False app.state.ENABLED = ENABLE_IMAGE_GENERATION
app.state.OPENAI_API_KEY = "" app.state.OPENAI_API_KEY = ""
app.state.MODEL = "" app.state.MODEL = ""
......
...@@ -612,8 +612,13 @@ async def generate_embeddings( ...@@ -612,8 +612,13 @@ async def generate_embeddings(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
if url_idx == None: if url_idx == None:
if form_data.model in app.state.MODELS: model = form_data.model
url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
...@@ -649,6 +654,60 @@ async def generate_embeddings( ...@@ -649,6 +654,60 @@ async def generate_embeddings(
) )
def generate_ollama_embeddings(
form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None,
):
log.info(f"generate_ollama_embeddings {form_data}")
if url_idx == None:
model = form_data.model
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status()
data = r.json()
log.info(f"generate_ollama_embeddings {data}")
if "embedding" in data:
return data["embedding"]
else:
raise "Something went wrong :/"
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise error_detail
class GenerateCompletionForm(BaseModel): class GenerateCompletionForm(BaseModel):
model: str model: str
prompt: str prompt: str
...@@ -672,8 +731,13 @@ async def generate_completion( ...@@ -672,8 +731,13 @@ async def generate_completion(
): ):
if url_idx == None: if url_idx == None:
if form_data.model in app.state.MODELS: model = form_data.model
url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
...@@ -770,8 +834,13 @@ async def generate_chat_completion( ...@@ -770,8 +834,13 @@ async def generate_chat_completion(
): ):
if url_idx == None: if url_idx == None:
if form_data.model in app.state.MODELS: model = form_data.model
url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
...@@ -874,8 +943,13 @@ async def generate_openai_chat_completion( ...@@ -874,8 +943,13 @@ async def generate_openai_chat_completion(
): ):
if url_idx == None: if url_idx == None:
if form_data.model in app.state.MODELS: model = form_data.model
url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
if ":" not in model:
model = f"{model}:latest"
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
......
...@@ -39,13 +39,22 @@ import uuid ...@@ -39,13 +39,22 @@ import uuid
import json import json
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
from apps.web.models.documents import ( from apps.web.models.documents import (
Documents, Documents,
DocumentForm, DocumentForm,
DocumentResponse, DocumentResponse,
) )
from apps.rag.utils import query_doc, query_collection, get_embedding_model_path from apps.rag.utils import (
query_doc,
query_embeddings_doc,
query_collection,
query_embeddings_collection,
get_embedding_model_path,
generate_openai_embeddings,
)
from utils.misc import ( from utils.misc import (
calculate_sha256, calculate_sha256,
...@@ -58,6 +67,7 @@ from config import ( ...@@ -58,6 +67,7 @@ from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
UPLOAD_DIR, UPLOAD_DIR,
DOCS_DIR, DOCS_DIR,
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_AUTO_UPDATE,
DEVICE_TYPE, DEVICE_TYPE,
...@@ -74,16 +84,21 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -74,16 +84,21 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI() app = FastAPI()
app.state.PDF_EXTRACT_IMAGES = False
app.state.TOP_K = 4
app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
app.state.RAG_OPENAI_API_KEY = ""
app.state.PDF_EXTRACT_IMAGES = False
app.state.TOP_K = 4
app.state.sentence_transformer_ef = ( app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction( embedding_functions.SentenceTransformerEmbeddingFunction(
...@@ -121,45 +136,72 @@ async def get_status(): ...@@ -121,45 +136,72 @@ async def get_status():
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE, "template": app.state.RAG_TEMPLATE,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
} }
@app.get("/embedding/model") @app.get("/embedding")
async def get_embedding_model(user=Depends(get_admin_user)): async def get_embedding_config(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY,
},
} }
class OpenAIConfigForm(BaseModel):
url: str
key: str
class EmbeddingModelUpdateForm(BaseModel): class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str
embedding_model: str embedding_model: str
@app.post("/embedding/model/update") @app.post("/embedding/update")
async def update_embedding_model( async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
): ):
log.info( log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
) )
try: try:
sentence_transformer_ef = ( app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=get_embedding_model_path(form_data.embedding_model, True),
device=DEVICE_TYPE,
)
)
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
app.state.sentence_transformer_ef = sentence_transformer_ef app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = None
if form_data.openai_config != None:
app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
else:
sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=get_embedding_model_path(
form_data.embedding_model, True
),
device=DEVICE_TYPE,
)
)
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = sentence_transformer_ef
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY,
},
} }
except Exception as e: except Exception as e:
...@@ -252,12 +294,37 @@ def query_doc_handler( ...@@ -252,12 +294,37 @@ def query_doc_handler(
): ):
try: try:
return query_doc( if app.state.RAG_EMBEDDING_ENGINE == "":
collection_name=form_data.collection_name, return query_doc(
query=form_data.query, collection_name=form_data.collection_name,
k=form_data.k if form_data.k else app.state.TOP_K, query=form_data.query,
embedding_function=app.state.sentence_transformer_ef, k=form_data.k if form_data.k else app.state.TOP_K,
) embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": app.state.RAG_EMBEDDING_MODEL,
"prompt": form_data.query,
}
)
)
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
return query_embeddings_doc(
collection_name=form_data.collection_name,
query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K,
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
...@@ -277,12 +344,45 @@ def query_collection_handler( ...@@ -277,12 +344,45 @@ def query_collection_handler(
form_data: QueryCollectionsForm, form_data: QueryCollectionsForm,
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
return query_collection( try:
collection_names=form_data.collection_names, if app.state.RAG_EMBEDDING_ENGINE == "":
query=form_data.query, return query_collection(
k=form_data.k if form_data.k else app.state.TOP_K, collection_names=form_data.collection_names,
embedding_function=app.state.sentence_transformer_ef, query=form_data.query,
) k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": app.state.RAG_EMBEDDING_MODEL,
"prompt": form_data.query,
}
)
)
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
return query_embeddings_collection(
collection_names=form_data.collection_names,
query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K,
)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@app.post("/web") @app.post("/web")
...@@ -317,9 +417,11 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b ...@@ -317,9 +417,11 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
chunk_overlap=app.state.CHUNK_OVERLAP, chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
docs = text_splitter.split_documents(data) docs = text_splitter.split_documents(data)
if len(docs) > 0: if len(docs) > 0:
log.info(f"store_data_in_vector_db {docs}")
return store_docs_in_vector_db(docs, collection_name, overwrite), None return store_docs_in_vector_db(docs, collection_name, overwrite), None
else: else:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
...@@ -338,6 +440,7 @@ def store_text_in_vector_db( ...@@ -338,6 +440,7 @@ def store_text_in_vector_db(
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
texts = [doc.page_content for doc in docs] texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
...@@ -349,18 +452,52 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b ...@@ -349,18 +452,52 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"deleting existing collection {collection_name}") log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name) CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection( if app.state.RAG_EMBEDDING_ENGINE == "":
name=collection_name,
embedding_function=app.state.sentence_transformer_ef, collection = CHROMA_CLIENT.create_collection(
) name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
documents=texts,
):
collection.add(*batch)
for batch in create_batches( else:
api=CHROMA_CLIENT, collection = CHROMA_CLIENT.create_collection(name=collection_name)
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas, if app.state.RAG_EMBEDDING_ENGINE == "ollama":
documents=texts, embeddings = [
): generate_ollama_embeddings(
collection.add(*batch) GenerateEmbeddingsForm(
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
)
)
for text in texts
]
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
embeddings = [
generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=text,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
for text in texts
]
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
embeddings=embeddings,
documents=texts,
):
collection.add(*batch)
return True return True
except Exception as e: except Exception as e:
......
...@@ -2,10 +2,16 @@ import os ...@@ -2,10 +2,16 @@ import os
import re import re
import logging import logging
from typing import List from typing import List
import requests
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
from config import SRC_LOG_LEVELS, CHROMA_CLIENT from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
...@@ -26,6 +32,24 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function): ...@@ -26,6 +32,24 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
raise e raise e
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
try:
# if you use docker use the model from the environment variable
log.info(f"query_embeddings_doc {query_embeddings}")
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
log.info(f"query_embeddings_doc:result {result}")
return result
except Exception as e:
raise e
def merge_and_sort_query_results(query_results, k): def merge_and_sort_query_results(query_results, k):
# Initialize lists to store combined data # Initialize lists to store combined data
combined_ids = [] combined_ids = []
...@@ -96,14 +120,46 @@ def query_collection( ...@@ -96,14 +120,46 @@ def query_collection(
return merge_and_sort_query_results(results, k) return merge_and_sort_query_results(results, k)
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
results = []
log.info(f"query_embeddings_collection {query_embeddings}")
for collection_name in collection_names:
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
results.append(result)
except:
pass
return merge_and_sort_query_results(results, k)
def rag_template(template: str, context: str, query: str): def rag_template(template: str, context: str, query: str):
template = template.replace("[context]", context) template = template.replace("[context]", context)
template = template.replace("[query]", query) template = template.replace("[query]", query)
return template return template
def rag_messages(docs, messages, template, k, embedding_function): def rag_messages(
log.debug(f"docs: {docs}") docs,
messages,
template,
k,
embedding_engine,
embedding_model,
embedding_function,
openai_key,
openai_url,
):
log.debug(
f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
)
last_user_message_idx = None last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1): for i in range(len(messages) - 1, -1, -1):
...@@ -136,22 +192,57 @@ def rag_messages(docs, messages, template, k, embedding_function): ...@@ -136,22 +192,57 @@ def rag_messages(docs, messages, template, k, embedding_function):
context = None context = None
try: try:
if doc["type"] == "collection":
context = query_collection( if doc["type"] == "text":
collection_names=doc["collection_names"],
query=query,
k=k,
embedding_function=embedding_function,
)
elif doc["type"] == "text":
context = doc["content"] context = doc["content"]
else: else:
context = query_doc( if embedding_engine == "":
collection_name=doc["collection_name"], if doc["type"] == "collection":
query=query, context = query_collection(
k=k, collection_names=doc["collection_names"],
embedding_function=embedding_function, query=query,
) k=k,
embedding_function=embedding_function,
)
else:
context = query_doc(
collection_name=doc["collection_name"],
query=query,
k=k,
embedding_function=embedding_function,
)
else:
if embedding_engine == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": embedding_model,
"prompt": query,
}
)
)
elif embedding_engine == "openai":
query_embeddings = generate_openai_embeddings(
model=embedding_model,
text=query,
key=openai_key,
url=openai_url,
)
if doc["type"] == "collection":
context = query_embeddings_collection(
collection_names=doc["collection_names"],
query_embeddings=query_embeddings,
k=k,
)
else:
context = query_embeddings_doc(
collection_name=doc["collection_name"],
query_embeddings=query_embeddings,
k=k,
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
context = None context = None
...@@ -230,3 +321,26 @@ def get_embedding_model_path( ...@@ -230,3 +321,26 @@ def get_embedding_model_path(
except Exception as e: except Exception as e:
log.exception(f"Cannot determine embedding model snapshot path: {e}") log.exception(f"Cannot determine embedding model snapshot path: {e}")
return embedding_model return embedding_model
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com"
):
try:
r = requests.post(
f"{url}/v1/embeddings",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
},
json={"input": text, "model": model},
)
r.raise_for_status()
data = r.json()
if "data" in data:
return data["data"][0]["embedding"]
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
return None
...@@ -18,6 +18,51 @@ from secrets import token_bytes ...@@ -18,6 +18,51 @@ from secrets import token_bytes
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
####################################
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"LITELLM",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
####################################
# Load .env file
####################################
try: try:
from dotenv import load_dotenv, find_dotenv from dotenv import load_dotenv, find_dotenv
...@@ -122,47 +167,6 @@ STATIC_DIR = str(Path(os.getenv("STATIC_DIR", "./static")).resolve()) ...@@ -122,47 +167,6 @@ STATIC_DIR = str(Path(os.getenv("STATIC_DIR", "./static")).resolve())
shutil.copyfile(f"{FRONTEND_BUILD_DIR}/favicon.png", f"{STATIC_DIR}/favicon.png") shutil.copyfile(f"{FRONTEND_BUILD_DIR}/favicon.png", f"{STATIC_DIR}/favicon.png")
####################################
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"LITELLM",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
#################################### ####################################
# CUSTOM_NAME # CUSTOM_NAME
#################################### ####################################
...@@ -401,6 +405,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": ...@@ -401,6 +405,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"), log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
...@@ -409,7 +416,7 @@ RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( ...@@ -409,7 +416,7 @@ RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
) )
# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
if USE_CUDA.lower() == "true": if USE_CUDA.lower() == "true":
...@@ -446,11 +453,17 @@ Query: [query]""" ...@@ -446,11 +453,17 @@ Query: [query]"""
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base") WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
#################################### ####################################
# Images # Images
#################################### ####################################
ENABLE_IMAGE_GENERATION = (
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
)
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
...@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware): ...@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
data["messages"], data["messages"],
rag_app.state.RAG_TEMPLATE, rag_app.state.RAG_TEMPLATE,
rag_app.state.TOP_K, rag_app.state.TOP_K,
rag_app.state.RAG_EMBEDDING_ENGINE,
rag_app.state.RAG_EMBEDDING_MODEL,
rag_app.state.sentence_transformer_ef, rag_app.state.sentence_transformer_ef,
rag_app.state.RAG_OPENAI_API_KEY,
rag_app.state.RAG_OPENAI_API_BASE_URL,
) )
del data["docs"] del data["docs"]
......
...@@ -10,7 +10,7 @@ ollama ...@@ -10,7 +10,7 @@ ollama
{{- if .Values.ollama.externalHost }} {{- if .Values.ollama.externalHost }}
{{- printf .Values.ollama.externalHost }} {{- printf .Values.ollama.externalHost }}
{{- else }} {{- else }}
{{- printf "http://%s.%s.svc.cluster.local:%d/" (include "ollama.name" .) (.Release.Namespace) (.Values.ollama.service.port | int) }} {{- printf "http://%s.%s.svc.cluster.local:%d" (include "ollama.name" .) (.Release.Namespace) (.Values.ollama.service.port | int) }}
{{- end }} {{- end }}
{{- end }} {{- end }}
......
{ {
"name": "open-webui", "name": "open-webui",
"version": "0.1.118", "version": "0.1.119",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "open-webui", "name": "open-webui",
"version": "0.1.118", "version": "0.1.119",
"dependencies": { "dependencies": {
"@sveltejs/adapter-node": "^1.3.1", "@sveltejs/adapter-node": "^1.3.1",
"async": "^3.2.5", "async": "^3.2.5",
......
{ {
"name": "open-webui", "name": "open-webui",
"version": "0.1.118", "version": "0.1.119",
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "vite dev --host", "dev": "vite dev --host",
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
<head> <head>
<meta charset="utf-8" /> <meta charset="utf-8" />
<link rel="icon" href="%sveltekit.assets%/favicon.png" /> <link rel="icon" href="%sveltekit.assets%/favicon.png" />
<link rel="manifest" href="%sveltekit.assets%/manifest.json" /> <link rel="manifest" href="%sveltekit.assets%/manifest.json" crossorigin="use-credentials" />
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1" /> <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1" />
<meta name="robots" content="noindex,nofollow" /> <meta name="robots" content="noindex,nofollow" />
<script> <script>
......
import { OLLAMA_API_BASE_URL } from '$lib/constants'; import { OLLAMA_API_BASE_URL } from '$lib/constants';
import { promptTemplate } from '$lib/utils';
export const getOllamaUrls = async (token: string = '') => { export const getOllamaUrls = async (token: string = '') => {
let error = null; let error = null;
...@@ -144,7 +145,7 @@ export const generateTitle = async ( ...@@ -144,7 +145,7 @@ export const generateTitle = async (
) => { ) => {
let error = null; let error = null;
template = template.replace(/{{prompt}}/g, prompt); template = promptTemplate(template, prompt);
console.log(template); console.log(template);
...@@ -219,6 +220,32 @@ export const generatePrompt = async (token: string = '', model: string, conversa ...@@ -219,6 +220,32 @@ export const generatePrompt = async (token: string = '', model: string, conversa
return res; return res;
}; };
export const generateEmbeddings = async (token: string = '', model: string, text: string) => {
let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/api/embeddings`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
prompt: text
})
}).catch((err) => {
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
export const generateTextCompletion = async (token: string = '', model: string, text: string) => { export const generateTextCompletion = async (token: string = '', model: string, text: string) => {
let error = null; let error = null;
......
import { OPENAI_API_BASE_URL } from '$lib/constants'; import { OPENAI_API_BASE_URL } from '$lib/constants';
import { promptTemplate } from '$lib/utils';
export const getOpenAIUrls = async (token: string = '') => { export const getOpenAIUrls = async (token: string = '') => {
let error = null; let error = null;
...@@ -273,7 +274,7 @@ export const generateTitle = async ( ...@@ -273,7 +274,7 @@ export const generateTitle = async (
) => { ) => {
let error = null; let error = null;
template = template.replace(/{{prompt}}/g, prompt); template = promptTemplate(template, prompt);
console.log(template); console.log(template);
......
...@@ -346,10 +346,10 @@ export const resetVectorDB = async (token: string) => { ...@@ -346,10 +346,10 @@ export const resetVectorDB = async (token: string) => {
return res; return res;
}; };
export const getEmbeddingModel = async (token: string) => { export const getEmbeddingConfig = async (token: string) => {
let error = null; let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/embedding/model`, { const res = await fetch(`${RAG_API_BASE_URL}/embedding`, {
method: 'GET', method: 'GET',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
...@@ -373,14 +373,21 @@ export const getEmbeddingModel = async (token: string) => { ...@@ -373,14 +373,21 @@ export const getEmbeddingModel = async (token: string) => {
return res; return res;
}; };
type OpenAIConfigForm = {
key: string;
url: string;
};
type EmbeddingModelUpdateForm = { type EmbeddingModelUpdateForm = {
openai_config?: OpenAIConfigForm;
embedding_engine: string;
embedding_model: string; embedding_model: string;
}; };
export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => { export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
let error = null; let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, { const res = await fetch(`${RAG_API_BASE_URL}/embedding/update`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
<button <button
class="self-center" class="self-center"
on:click={() => { on:click={() => {
localStorage.version = $config.version;
show = false; show = false;
}} }}
> >
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import { synthesizeOpenAISpeech } from '$lib/apis/openai'; import { synthesizeOpenAISpeech } from '$lib/apis/openai';
import { imageGenerations } from '$lib/apis/images'; import { imageGenerations } from '$lib/apis/images';
import { import {
approximateToHumanReadable,
extractSentences, extractSentences,
revertSanitizedResponseContent, revertSanitizedResponseContent,
sanitizeResponseContent sanitizeResponseContent
...@@ -122,7 +123,10 @@ ...@@ -122,7 +123,10 @@
eval_count: ${message.info.eval_count ?? 'N/A'}<br/> eval_count: ${message.info.eval_count ?? 'N/A'}<br/>
eval_duration: ${ eval_duration: ${
Math.round(((message.info.eval_duration ?? 0) / 1000000) * 100) / 100 ?? 'N/A' Math.round(((message.info.eval_duration ?? 0) / 1000000) * 100) / 100 ?? 'N/A'
}ms</span>`, }ms<br/>
approximate_total: ${approximateToHumanReadable(
message.info.total_duration
)}</span>`,
allowHTML: true allowHTML: true
}); });
} }
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
: items; : items;
const pullModelHandler = async () => { const pullModelHandler = async () => {
const sanitizedModelTag = searchValue.trim(); const sanitizedModelTag = searchValue.trim().replace(/^ollama\s+(run|pull)\s+/, '');
console.log($MODEL_DOWNLOAD_POOL); console.log($MODEL_DOWNLOAD_POOL);
if ($MODEL_DOWNLOAD_POOL[sanitizedModelTag]) { if ($MODEL_DOWNLOAD_POOL[sanitizedModelTag]) {
......
...@@ -139,7 +139,7 @@ ...@@ -139,7 +139,7 @@
}; };
const pullModelHandler = async () => { const pullModelHandler = async () => {
const sanitizedModelTag = modelTag.trim(); const sanitizedModelTag = modelTag.trim().replace(/^ollama\s+(run|pull)\s+/, '');
if (modelDownloadStatus[sanitizedModelTag]) { if (modelDownloadStatus[sanitizedModelTag]) {
toast.error( toast.error(
$i18n.t(`Model '{{modelTag}}' is already in queue for downloading.`, { $i18n.t(`Model '{{modelTag}}' is already in queue for downloading.`, {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment