Unverified Commit 748cb7d4 authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #1654 from open-webui/dev

0.1.121
parents b3da09f5 348186c4
...@@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.121] - 2024-04-24
### Fixed
- **🔧 Translation Issues**: Addressed various translation discrepancies.
- **🔒 LiteLLM Security Fix**: Updated LiteLLM version to resolve a security vulnerability.
- **🖥️ HTML Tag Display**: Rectified the issue where the '< br >' tag wasn't displaying correctly.
- **🔗 WebSocket Connection**: Resolved WebSocket connection failures when the ComfyUI server is reached over HTTPS.
- **📜 FileReader Optimization**: Initialized a separate FileReader per image in multi-file drag & drop so every dropped file is read reliably.
- **🏷️ Tag Display**: Corrected tag display inconsistencies.
- **📦 Archived Chat Styling**: Fixed styling issues in archived chat.
- **🔖 Safari Copy Button Bug**: Addressed the bug where the copy button failed to copy links in Safari.
## [0.1.120] - 2024-04-20
### Added
......
...@@ -8,8 +8,8 @@ ARG USE_CUDA_VER=cu121
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilingual support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you change the default model (sentence-transformers/all-MiniLM-L6-v2), documents already loaded into the WebUI can no longer be used for RAG Chat until you re-embed them.
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
######## WebUI frontend ########
FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
...@@ -98,13 +98,13 @@ RUN pip3 install uv && \
# If you use CUDA the whisper and embedding model will be downloaded on first use
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
fi
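Expanded for readability, the two inline `python -c` commands in the RUN step above amount to the following sketch (same calls, just laid out as a script; it assumes RAG_EMBEDDING_MODEL, WHISPER_MODEL, and WHISPER_MODEL_DIR are set as in the Dockerfile):

import os
from faster_whisper import WhisperModel
from sentence_transformers import SentenceTransformer

# Cache the sentence-transformers embedding model at image build time
SentenceTransformer(os.environ["RAG_EMBEDDING_MODEL"], device="cpu")

# Cache the faster-whisper speech-to-text model into WHISPER_MODEL_DIR
WhisperModel(
    os.environ["WHISPER_MODEL"],
    device="cpu",
    compute_type="int8",
    download_root=os.environ["WHISPER_MODEL_DIR"],
)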
......
...@@ -35,8 +35,8 @@ from config import (
ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL,
COMFYUI_BASE_URL,
IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
)
...@@ -58,8 +58,8 @@ app.add_middleware(
app.state.ENGINE = ""
app.state.ENABLED = ENABLE_IMAGE_GENERATION
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.MODEL = ""
...@@ -135,27 +135,33 @@ async def update_engine_url(
}
class OpenAIConfigUpdateForm(BaseModel):
url: str
key: str
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}
@app.post("/openai/config/update")
async def update_openai_config(
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}
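For context, a minimal sketch of how an admin client could exercise the two endpoints above; the WebUI address, the /images/api/v1 mount prefix, and the bearer token are illustrative assumptions, not something this PR defines:

import requests

BASE = "http://localhost:8080/images/api/v1"  # assumed mount point of the images sub-app
headers = {"Authorization": "Bearer <admin-token>"}  # assumed admin JWT

# Read the current image-generation OpenAI settings
print(requests.get(f"{BASE}/openai/config", headers=headers).json())

# Update the base URL and key together (both fields are required by OpenAIConfigUpdateForm)
r = requests.post(
    f"{BASE}/openai/config/update",
    headers=headers,
    json={"url": "https://api.openai.com/v1", "key": "sk-..."},
)
print(r.json())  # {"status": True, "OPENAI_API_BASE_URL": ..., "OPENAI_API_KEY": ...}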
......
import sys
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
import logging
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
import time
import requests
from pydantic import BaseModel, ConfigDict
from typing import Optional, List
from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, ENV
from constants import MESSAGES
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
...@@ -20,81 +28,324 @@ log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
DATA_DIR,
LITELLM_PROXY_PORT,
LITELLM_PROXY_HOST,
)
from litellm.utils import get_llm_provider
import asyncio
import subprocess
import yaml
app = FastAPI()
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"
with open(LITELLM_CONFIG_DIR, "r") as file:
litellm_config = yaml.safe_load(file)
app.state.CONFIG = litellm_config
# Global variable to store the subprocess reference
background_process = None
async def run_background_process(command):
global background_process
log.info("run_background_process")
try:
# Log the command to be executed
log.info(f"Executing command: {command}")
# Execute the command and create a subprocess
process = await asyncio.create_subprocess_exec(
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
background_process = process
log.info("Subprocess started successfully.")
# Capture STDERR for debugging purposes
stderr_output = await process.stderr.read()
stderr_text = stderr_output.decode().strip()
if stderr_text:
log.info(f"Subprocess STDERR: {stderr_text}")
# log.info output line by line
async for line in process.stdout:
log.info(line.decode().strip())
# Wait for the process to finish
returncode = await process.wait()
log.info(f"Subprocess exited with return code {returncode}")
except Exception as e:
log.error(f"Failed to start subprocess: {e}")
raise # Optionally re-raise the exception if you want it to propagate
async def start_litellm_background():
log.info("start_litellm_background")
# Command to run in the background
command = [
"litellm",
"--port",
str(LITELLM_PROXY_PORT),
"--host",
LITELLM_PROXY_HOST,
"--telemetry",
"False",
"--config",
LITELLM_CONFIG_DIR,
]
await run_background_process(command)
async def shutdown_litellm_background():
log.info("shutdown_litellm_background")
global background_process
if background_process:
background_process.terminate()
await background_process.wait() # Ensure the process has terminated
log.info("Subprocess terminated")
background_process = None
@app.on_event("startup") @app.on_event("startup")
async def on_startup(): async def startup_event():
await startup() log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.middleware("http") @app.get("/")
async def auth_middleware(request: Request, call_next): async def get_status():
auth_header = request.headers.get("Authorization", "") return {"status": True}
request.state.user = None
async def restart_litellm():
"""
Endpoint to restart the litellm background service.
"""
log.info("Requested restart of litellm service.")
try:
# Shut down the existing process if it is running
await shutdown_litellm_background()
log.info("litellm service shutdown complete.")
# Restart the background service
asyncio.create_task(start_litellm_background())
log.info("litellm service restart complete.")
return {
"status": "success",
"message": "litellm service restarted successfully.",
}
except Exception as e:
log.info(f"Error restarting litellm service: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
return await restart_litellm()
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return app.state.CONFIG
class LiteLLMConfigForm(BaseModel):
general_settings: Optional[dict] = None
litellm_settings: Optional[dict] = None
model_list: Optional[List[dict]] = None
router_settings: Optional[dict] = None
model_config = ConfigDict(protected_namespaces=())
@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
app.state.CONFIG = form_data.model_dump(exclude_none=True)
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return app.state.CONFIG
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
while not background_process:
await asyncio.sleep(0.1)
url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
data = r.json()
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
return {
"data": [
{
"id": model["model_name"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for model in app.state.CONFIG["model_list"]
],
"object": "list",
}
@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
return {"data": app.state.CONFIG["model_list"]}
class AddLiteLLMModelForm(BaseModel):
model_name: str
litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
@app.post("/model/new")
async def add_model_to_config(
form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
try:
get_llm_provider(model=form_data.model_name)
app.state.CONFIG["model_list"].append(form_data.model_dump())
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
class DeleteLiteLLMModelForm(BaseModel):
id: str
@app.post("/model/delete")
async def delete_model_from_config(
form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
app.state.CONFIG["model_list"] = [
model
for model in app.state.CONFIG["model_list"]
if model["model_name"] != form_data.id
]
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
body = await request.body()
url = f"http://localhost:{LITELLM_PROXY_PORT}"
target_url = f"{url}/{path}"
headers = {}
# headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
else:
response_data = r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
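For reference, a minimal sketch of driving the admin endpoints defined above (model registration is persisted to config.yaml and the background litellm proxy is restarted); the WebUI address, the /litellm mount prefix, the token, and the litellm_params payload are illustrative assumptions, not part of this PR:

import requests

BASE = "http://localhost:8080/litellm"  # assumed mount point of the litellm sub-app
headers = {"Authorization": "Bearer <admin-token>"}  # assumed admin JWT

# Register a model; the server appends it to model_list and calls restart_litellm()
requests.post(
    f"{BASE}/model/new",
    headers=headers,
    json={
        "model_name": "my-gpt-3.5",
        "litellm_params": {"model": "openai/gpt-3.5-turbo", "api_key": "sk-..."},
    },
)

# Admins can inspect the raw config entries
print(requests.get(f"{BASE}/model/info", headers=headers).json())

# Remove the model again by its model_name
requests.post(f"{BASE}/model/delete", headers=headers, json={"id": "my-gpt-3.5"})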
...@@ -13,7 +13,6 @@ import os, shutil, logging, re
from pathlib import Path
from typing import List
from chromadb.utils.batch_utils import create_batches
from langchain_community.document_loaders import (
...@@ -38,6 +37,7 @@ import mimetypes
import uuid
import json
import sentence_transformers
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
...@@ -48,11 +48,8 @@ from apps.web.models.documents import (
)
from apps.rag.utils import (
query_embeddings_doc,
query_embeddings_collection,
generate_openai_embeddings,
)
...@@ -69,7 +66,7 @@ from config import (
DOCS_DIR,
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
DEVICE_TYPE,
...@@ -101,15 +98,12 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = False
if app.state.RAG_EMBEDDING_ENGINE == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
app.state.RAG_EMBEDDING_MODEL,
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
origins = ["*"]
...@@ -185,13 +179,10 @@ async def update_embedding_config(
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key
else:
sentence_transformer_ef = sentence_transformers.SentenceTransformer(
app.state.RAG_EMBEDDING_MODEL,
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = sentence_transformer_ef
...@@ -294,39 +285,35 @@ def query_doc_handler(
form_data: QueryDocForm,
user=Depends(get_current_user),
):
try:
if app.state.RAG_EMBEDDING_ENGINE == "":
query_embeddings = app.state.sentence_transformer_ef.encode(
form_data.query
).tolist()
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": app.state.RAG_EMBEDDING_MODEL,
"prompt": form_data.query,
}
)
)
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.OPENAI_API_KEY,
url=app.state.OPENAI_API_BASE_URL,
)
return query_embeddings_doc(
collection_name=form_data.collection_name,
query=form_data.query,
query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K,
)
except Exception as e:
log.exception(e)
raise HTTPException(
...@@ -348,37 +335,32 @@ def query_collection_handler(
):
try:
if app.state.RAG_EMBEDDING_ENGINE == "":
query_embeddings = app.state.sentence_transformer_ef.encode(
form_data.query
).tolist()
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": app.state.RAG_EMBEDDING_MODEL,
"prompt": form_data.query,
}
)
)
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.OPENAI_API_KEY,
url=app.state.OPENAI_API_BASE_URL,
)
return query_embeddings_collection(
collection_names=form_data.collection_names,
query=form_data.query,
query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K,
)
except Exception as e:
log.exception(e)
raise HTTPException(
...@@ -445,6 +427,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
texts = [doc.page_content for doc in docs]
texts = list(map(lambda x: x.replace("\n", " "), texts))
metadatas = [doc.metadata for doc in docs]
try:
...@@ -454,52 +438,38 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(name=collection_name)
if app.state.RAG_EMBEDDING_ENGINE == "":
embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
embeddings = [
generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
)
)
for text in texts
]
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
embeddings = [
generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=text,
key=app.state.OPENAI_API_KEY,
url=app.state.OPENAI_API_BASE_URL,
)
for text in texts
]
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
embeddings=embeddings,
documents=texts,
):
collection.add(*batch)
return True
except Exception as e:
......
import logging
import requests
from typing import List
from apps.ollama.main import (
generate_ollama_embeddings,
GenerateEmbeddingsForm,
)
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
...@@ -16,29 +15,12 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
try:
# if you use docker use the model from the environment variable
log.info(f"query_embeddings_doc {query_embeddings}")
collection = CHROMA_CLIENT.get_collection(name=collection_name)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
...@@ -95,43 +77,20 @@ def merge_and_sort_query_results(query_results, k):
return merged_query_results
def query_embeddings_collection(
collection_names: List[str], query: str, query_embeddings, k: int
):
results = []
log.info(f"query_embeddings_collection {query_embeddings}")
for collection_name in collection_names:
try:
result = query_embeddings_doc(
collection_name=collection_name,
query=query,
query_embeddings=query_embeddings,
k=k,
)
results.append(result)
except:
...@@ -197,51 +156,38 @@ def rag_messages(
context = doc["content"]
else:
if embedding_engine == "":
query_embeddings = embedding_function.encode(query).tolist()
elif embedding_engine == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": embedding_model,
"prompt": query,
}
)
)
elif embedding_engine == "openai":
query_embeddings = generate_openai_embeddings(
model=embedding_model,
text=query,
key=openai_key,
url=openai_url,
)
if doc["type"] == "collection":
context = query_embeddings_collection(
collection_names=doc["collection_names"],
query=query,
query_embeddings=query_embeddings,
k=k,
)
else:
context = query_embeddings_doc(
collection_name=doc["collection_name"],
query=query,
query_embeddings=query_embeddings,
k=k,
)
except Exception as e:
log.exception(e)
...@@ -283,46 +229,6 @@ def rag_messages(
return messages
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
......
...@@ -28,7 +28,7 @@ from apps.web.models.tags import (
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
...@@ -79,6 +79,11 @@ async def get_all_user_chats(user=Depends(get_current_user)):
@router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats()
......
...@@ -91,7 +91,11 @@ async def download_chat_as_pdf(
@router.get("/db/download")
async def download_db(user=Depends(get_admin_user)):
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return FileResponse(
f"{DATA_DIR}/webui.db",
media_type="application/octet-stream",
......
...@@ -382,6 +382,8 @@ MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
####################################
# WEBUI_VERSION
####################################
...@@ -416,18 +418,19 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
####################################
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
# this uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model (sentence-transformers/all-MiniLM-L6-v2) will be used
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
RAG_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}")
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
# device type for embedding models - "cpu" (default), "cuda" (NVIDIA GPU required) or "mps" (Apple Silicon) - choosing the right one can improve performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
...@@ -484,9 +487,24 @@ AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
IMAGES_OPENAI_API_BASE_URL = os.getenv(
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
)
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
####################################
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
####################################
# LiteLLM
####################################
LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
...@@ -3,6 +3,10 @@ from enum import Enum
class MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}"
MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."
MODEL_DELETED = (
lambda model="": f"The model '{model}' has been deleted successfully."
)
class WEBHOOK_MESSAGES(str, Enum):
......
...@@ -20,12 +20,17 @@ from starlette.middleware.base import BaseHTTPMiddleware
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.litellm.main import (
app as litellm_app,
start_litellm_background,
shutdown_litellm_background,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app
import asyncio
from pydantic import BaseModel
from typing import List
...@@ -47,6 +52,7 @@ from config import (
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
)
from constants import ERROR_MESSAGES
...@@ -170,7 +176,7 @@ async def check_url(request: Request, call_next):
@app.on_event("startup")
async def on_startup():
asyncio.create_task(start_litellm_background())
app.mount("/api/v1", webui_app)
...@@ -202,6 +208,7 @@ async def get_app_config():
"default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
}
...@@ -315,3 +322,8 @@ app.mount(
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
name="spa-static-files",
)
@app.on_event("shutdown")
async def shutdown_event():
await shutdown_litellm_background()
...@@ -17,7 +17,9 @@ peewee
peewee-migrate
bcrypt
litellm==1.35.17
litellm[proxy]==1.35.17
boto3
argon2-cffi
...@@ -25,6 +27,7 @@ apscheduler
google-generativeai
langchain
langchain-chroma
langchain-community
fake_useragent
chromadb
...@@ -43,6 +46,7 @@ opencv-python-headless
rapidocr-onnxruntime
fpdf2
rank_bm25
faster-whisper
......
{
"name": "open-webui",
"version": "0.1.121",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "open-webui",
"version": "0.1.121",
"dependencies": {
"@sveltejs/adapter-node": "^1.3.1",
"async": "^3.2.5",
......
{
"name": "open-webui",
"version": "0.1.121",
"private": true,
"scripts": {
"dev": "vite dev --host",
......
...@@ -72,10 +72,10 @@ export const updateImageGenerationConfig = async (
return res;
};
export const getOpenAIConfig = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config`, {
method: 'GET',
headers: {
Accept: 'application/json',
...@@ -101,13 +101,13 @@ export const getOpenAIKey = async (token: string = '') => {
throw error;
}
return res;
};
export const updateOpenAIConfig = async (token: string = '', url: string, key: string) => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
...@@ -115,6 +115,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
url: url,
key: key
})
})
...@@ -136,7 +137,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
throw error;
}
return res;
};
export const getImageGenerationEngineUrls = async (token: string = '') => {
......
type TextStreamUpdate = {
done: boolean;
value: string;
};
// createOpenAITextStream takes a ReadableStreamDefaultReader from an SSE response,
// and returns an async generator that emits delta updates with large deltas chunked into random sized chunks
export async function createOpenAITextStream(
messageStream: ReadableStreamDefaultReader,
splitLargeDeltas: boolean
): Promise<AsyncGenerator<TextStreamUpdate>> {
let iterator = openAIStreamToIterator(messageStream);
if (splitLargeDeltas) {
iterator = streamLargeDeltasAsRandomChunks(iterator);
}
return iterator;
}
async function* openAIStreamToIterator(
reader: ReadableStreamDefaultReader
): AsyncGenerator<TextStreamUpdate> {
while (true) {
const { value, done } = await reader.read();
if (done) {
yield { done: true, value: '' };
break;
}
const lines = value.split('\n');
for (const line of lines) {
if (line !== '') {
console.log(line);
if (line === 'data: [DONE]') {
yield { done: true, value: '' };
} else {
const data = JSON.parse(line.replace(/^data: /, ''));
console.log(data);
yield { done: false, value: data.choices[0].delta.content ?? '' };
}
}
}
}
}
// streamLargeDeltasAsRandomChunks will chunk large deltas (length > 5) into random sized chunks between 1-3 characters
// This is to simulate a more fluid streaming, even though some providers may send large chunks of text at once
async function* streamLargeDeltasAsRandomChunks(
iterator: AsyncGenerator<TextStreamUpdate>
): AsyncGenerator<TextStreamUpdate> {
for await (const textStreamUpdate of iterator) {
if (textStreamUpdate.done) {
yield textStreamUpdate;
return;
}
let content = textStreamUpdate.value;
if (content.length < 5) {
yield { done: false, value: content };
continue;
}
while (content != '') {
const chunkSize = Math.min(Math.floor(Math.random() * 3) + 1, content.length);
const chunk = content.slice(0, chunkSize);
yield { done: false, value: chunk };
await sleep(5);
content = content.slice(chunkSize);
}
}
}
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
<script lang="ts">
import { downloadDatabase } from '$lib/apis/utils';
import { onMount, getContext } from 'svelte';
import { config } from '$lib/stores';
const i18n = getContext('i18n');
...@@ -24,32 +25,34 @@
<div class=" flex w-full justify-between">
<!-- <div class=" self-center text-xs font-medium">{$i18n.t('Allow Chat Deletion')}</div> -->
{#if $config?.admin_export_enabled ?? true}
<button
class=" flex rounded-md py-1.5 px-3 w-full hover:bg-gray-200 dark:hover:bg-gray-800 transition"
type="button"
on:click={() => {
// exportAllUserChats();
downloadDatabase(localStorage.token);
}}
>
<div class=" self-center mr-3">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path d="M2 3a1 1 0 0 1 1-1h10a1 1 0 0 1 1 1v1a1 1 0 0 1-1 1H3a1 1 0 0 1-1-1V3Z" />
<path
fill-rule="evenodd"
d="M13 6H3v6a2 2 0 0 0 2 2h6a2 2 0 0 0 2-2V6ZM8.75 7.75a.75.75 0 0 0-1.5 0v2.69L6.03 9.22a.75.75 0 0 0-1.06 1.06l2.5 2.5a.75.75 0 0 0 1.06 0l2.5-2.5a.75.75 0 1 0-1.06-1.06l-1.22 1.22V7.75Z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class=" self-center text-sm font-medium">{$i18n.t('Download Database')}</div>
</button>
{/if}
</div>
</div>
</div>
......
...@@ -75,14 +75,16 @@
};
const updateConfigHandler = async () => {
if (TTSEngine === 'openai') {
const res = await updateAudioConfig(localStorage.token, {
url: OpenAIUrl,
key: OpenAIKey
});
if (res) {
OpenAIUrl = res.OPENAI_API_BASE_URL;
OpenAIKey = res.OPENAI_API_KEY;
}
}
};
......
...@@ -301,7 +301,7 @@
</button>
{/if}
{#if $user?.role === 'admin' && ($config?.admin_export_enabled ?? true)}
<hr class=" dark:border-gray-700" />
<button
......
...@@ -15,8 +15,8 @@
updateImageSize,
getImageSteps,
updateImageSteps,
getOpenAIConfig,
updateOpenAIConfig
} from '$lib/apis/images';
import { getBackendConfig } from '$lib/apis';
const dispatch = createEventDispatcher();
...@@ -33,6 +33,7 @@
let AUTOMATIC1111_BASE_URL = '';
let COMFYUI_BASE_URL = '';
let OPENAI_API_BASE_URL = '';
let OPENAI_API_KEY = '';
let selectedModel = '';
...@@ -131,7 +132,10 @@
AUTOMATIC1111_BASE_URL = URLS.AUTOMATIC1111_BASE_URL;
COMFYUI_BASE_URL = URLS.COMFYUI_BASE_URL;
const config = await getOpenAIConfig(localStorage.token);
OPENAI_API_KEY = config.OPENAI_API_KEY;
OPENAI_API_BASE_URL = config.OPENAI_API_BASE_URL;
imageSize = await getImageSize(localStorage.token);
steps = await getImageSteps(localStorage.token);
...@@ -149,7 +153,7 @@
loading = true;
if (imageGenerationEngine === 'openai') {
await updateOpenAIConfig(localStorage.token, OPENAI_API_BASE_URL, OPENAI_API_KEY);
}
await updateDefaultImageGenerationModel(localStorage.token, selectedModel);
...@@ -300,13 +304,22 @@
</button>
</div>
{:else if imageGenerationEngine === 'openai'}
<div>
<div class=" mb-1.5 text-sm font-medium">{$i18n.t('OpenAI API Config')}</div>
<div class="flex gap-2 mb-1">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={OPENAI_API_BASE_URL}
required
/>
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Key')}
bind:value={OPENAI_API_KEY}
required
/>
</div>
</div>
...@@ -319,19 +332,39 @@
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('Set Default Model')}</div>
<div class="flex w-full">
<div class="flex-1 mr-2">
{#if imageGenerationEngine === 'openai' && !OPENAI_API_BASE_URL.includes('https://api.openai.com')}
<div class="flex w-full">
<div class="flex-1">
<input
list="model-list"
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedModel}
placeholder="Select a model"
/>
<datalist id="model-list">
{#each models ?? [] as model}
<option value={model.id}>{model.name}</option>
{/each}
</datalist>
</div>
</div>
{:else}
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedModel}
placeholder={$i18n.t('Select a model')}
required
>
{#if !selectedModel}
<option value="" disabled selected>{$i18n.t('Select a model')}</option>
{/if}
{#each models ?? [] as model}
<option value={model.id} class="bg-gray-100 dark:bg-gray-700">{model.name}</option
>
{/each}
</select>
{/if}
</div>
</div>
</div>
......