Commit 4ff17acc authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

Merge remote-tracking branch 'upstream/dev' into feat/oauth

parents f49d814d 9928114c
......@@ -73,7 +73,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
############################
@router.get("/name/{name}", response_model=Optional[DocumentResponse])
@router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
......@@ -105,7 +105,7 @@ class TagDocumentForm(BaseModel):
tags: List[dict]
@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
......@@ -128,7 +128,7 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
############################
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
):
......@@ -152,7 +152,7 @@ async def update_doc_by_name(
############################
@router.delete("/name/{name}/delete", response_model=bool)
@router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result
......@@ -44,6 +44,10 @@ class AddMemoryForm(BaseModel):
content: str
class MemoryUpdateModel(BaseModel):
content: Optional[str] = None
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
......@@ -62,6 +66,34 @@ async def add_memory(
return memory
@router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
async def update_memory_by_id(
memory_id: str,
request: Request,
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
):
memory = Memories.update_memory_by_id(memory_id, form_data.content)
if memory is None:
raise HTTPException(status_code=404, detail="Memory not found")
if form_data.content is not None:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = CHROMA_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
documents=[form_data.content],
ids=[memory.id],
embeddings=[memory_embedding],
metadatas=[
{"created_at": memory.created_at, "updated_at": memory.updated_at}
],
)
return memory
############################
# QueryMemory
############################
......
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from config import DATA_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()
############################
# GetToolkits
############################
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# ExportToolKits
############################
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# CreateNewToolKit
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id)
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetToolkitById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolkitById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id)
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
specs = get_tools_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteToolkitById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
del TOOLS[id]
# delete the toolkit file
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
os.remove(toolkit_path)
return result
......@@ -7,6 +7,8 @@ from pydantic import BaseModel
from fpdf import FPDF
import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
......@@ -26,6 +28,21 @@ async def get_gravatar(
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
class MarkdownForm(BaseModel):
md: str
......
from importlib import util
import os
from config import TOOLS_DIR
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {toolkit_id}")
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e
......@@ -435,7 +435,11 @@ STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png"
if frontend_favicon.exists():
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
try:
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
......@@ -493,6 +497,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# LITELLM_CONFIG
####################################
......@@ -542,6 +554,7 @@ OLLAMA_API_BASE_URL = os.environ.get(
)
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
AIOHTTP_CLIENT_TIMEOUT = int(os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300"))
K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
......@@ -744,6 +757,78 @@ ADMIN_EMAIL = PersistentConfig(
)
####################################
# TASKS
####################################
TASK_MODEL = PersistentConfig(
"TASK_MODEL",
"task.model.default",
os.environ.get("TASK_MODEL", ""),
)
TASK_MODEL_EXTERNAL = PersistentConfig(
"TASK_MODEL_EXTERNAL",
"task.model.external",
os.environ.get("TASK_MODEL_EXTERNAL", ""),
)
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"task.title.prompt_template",
os.environ.get(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"""Here is the query:
{{prompt:middletruncate:8000}}
Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
),
)
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"task.search.prompt_template",
os.environ.get(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"""You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
Question:
{{prompt:end:4000}}""",
),
)
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold",
int(
os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
100,
)
),
)
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
os.environ.get(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"""Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
),
)
####################################
# WEBUI_SECRET_KEY
####################################
......@@ -991,6 +1076,17 @@ SERPER_API_KEY = PersistentConfig(
os.getenv("SERPER_API_KEY", ""),
)
SERPLY_API_KEY = PersistentConfig(
"SERPLY_API_KEY",
"rag.web.search.serply_api_key",
os.getenv("SERPLY_API_KEY", ""),
)
TAVILY_API_KEY = PersistentConfig(
"TAVILY_API_KEY",
"rag.web.search.tavily_api_key",
os.getenv("TAVILY_API_KEY", ""),
)
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",
......@@ -1072,25 +1168,59 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_OPENAI_API_BASE_URL",
"audio.openai.api_base_url",
os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_OPENAI_API_KEY = PersistentConfig(
"AUDIO_OPENAI_API_KEY",
"audio.openai.api_key",
os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_OPENAI_API_MODEL = PersistentConfig(
"AUDIO_OPENAI_API_MODEL",
"audio.openai.api_model",
os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
)
AUDIO_OPENAI_API_VOICE = PersistentConfig(
"AUDIO_OPENAI_API_VOICE",
"audio.openai.api_voice",
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
"audio.stt.openai.api_base_url",
os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
"AUDIO_STT_OPENAI_API_KEY",
"audio.stt.openai.api_key",
os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_STT_ENGINE = PersistentConfig(
"AUDIO_STT_ENGINE",
"audio.stt.engine",
os.getenv("AUDIO_STT_ENGINE", ""),
)
AUDIO_STT_MODEL = PersistentConfig(
"AUDIO_STT_MODEL",
"audio.stt.model",
os.getenv("AUDIO_STT_MODEL", "whisper-1"),
)
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_TTS_OPENAI_API_BASE_URL",
"audio.tts.openai.api_base_url",
os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
"AUDIO_TTS_OPENAI_API_KEY",
"audio.tts.openai.api_key",
os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_TTS_ENGINE = PersistentConfig(
"AUDIO_TTS_ENGINE",
"audio.tts.engine",
os.getenv("AUDIO_TTS_ENGINE", ""),
)
AUDIO_TTS_MODEL = PersistentConfig(
"AUDIO_TTS_MODEL",
"audio.tts.model",
os.getenv("AUDIO_TTS_MODEL", "tts-1"),
)
AUDIO_TTS_VOICE = PersistentConfig(
"AUDIO_TTS_VOICE",
"audio.tts.voice",
os.getenv("AUDIO_TTS_VOICE", "alloy"),
)
......
......@@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum):
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
......
......@@ -13,8 +13,12 @@ import logging
import aiohttp
import requests
import mimetypes
import shutil
import os
import inspect
import asyncio
from fastapi import FastAPI, Request, Depends, status
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException
......@@ -27,21 +31,33 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import app as socket_app
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
from apps.ollama.main import (
app as ollama_app,
OpenAIChatCompletionForm,
get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion,
)
from apps.openai.main import (
app as openai_app,
get_all_models as get_openai_models,
generate_chat_completion as generate_openai_chat_completion,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app
import asyncio
from pydantic import BaseModel
from typing import List, Optional
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.models.users import Users
from apps.webui.utils import load_toolkit_module_by_id
from utils.misc import parse_duration
from utils.utils import (
get_admin_user,
......@@ -51,7 +67,14 @@ from utils.utils import (
get_password_hash,
create_token,
)
from apps.rag.utils import rag_messages
from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from apps.rag.utils import get_rag_context, rag_template
from config import (
CONFIG_DATA,
......@@ -72,14 +95,20 @@ from config import (
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
AppConfig,
WEBUI_BUILD_HASH,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
......@@ -134,27 +163,133 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.MODELS = {}
origins = ["*"]
# Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next):
# response: Response = await call_next(request)
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
# return response
async def get_function_call_response(messages, tool_id, template, task_model_id, user):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
user_message = get_last_user_message(messages)
prompt = (
"History:\n"
+ "\n".join(
[
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
]
)
+ f"\nQuery: {user_message}"
)
print(prompt)
payload = {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
}
try:
payload = filter_pipeline(payload, user)
except Exception as e:
raise e
model = app.state.MODELS[task_model_id]
# app.add_middleware(SecurityHeadersMiddleware)
response = None
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
content = None
class RAGMiddleware(BaseHTTPMiddleware):
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
# Parse the function response
if content is not None:
print(f"content: {content}")
result = json.loads(content)
print(result)
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
# Check if '__user__' is a parameter of the function
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
function_result = function(
**{
**result["parameters"],
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
}
)
else:
# Call the function without modifying the parameters
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
except Exception as e:
print(f"Error: {e}")
return None
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
......@@ -171,35 +306,98 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
# Remove the citations from the body
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
# Set the task model
task_model_id = data["model"]
if task_model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
try:
response = await get_function_call_response(
messages=data["messages"],
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
if response:
context += ("\n" if context != "" else "") + response
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
# If docs field is present, generate RAG completions
if "docs" in data:
data = {**data}
data["messages"], citations = rag_messages(
rag_context, citations = get_rag_context(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.config.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
del data["docs"]
log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
......@@ -242,7 +440,80 @@ class RAGMiddleware(BaseHTTPMiddleware):
yield data
app.add_middleware(RAGMiddleware)
app.add_middleware(ChatCompletionMiddleware)
def filter_pipeline(payload, user):
user = {"id": user.id, "name": user.name, "role": user.role}
model_id = payload["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters.append(model)
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": payload,
},
)
r.raise_for_status()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
except:
pass
if "detail" in res:
raise Exception(r.status_code, res["detail"])
else:
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]
if "task" in payload:
del payload["task"]
return payload
class PipelineMiddleware(BaseHTTPMiddleware):
......@@ -260,85 +531,17 @@ class PipelineMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
model_id = data["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
user = None
if len(sorted_filters) > 0:
try:
user = get_current_user(
get_http_authorization_cred(
request.headers.get("Authorization")
)
)
user = {"id": user.id, "name": user.name, "role": user.role}
except:
pass
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters.append(model)
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in data:
del data["chat_id"]
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
if "title" in data:
del data["title"]
try:
data = filter_pipeline(data, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
......@@ -499,6 +702,302 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models}
@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@app.post("/api/task/config/update")
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
app.state.config.TASK_MODEL = form_data.TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@app.post("/api/task/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
print("generate_title")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
content = title_generation_template(
template, form_data["prompt"], user.model_dump()
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 50,
"chat_id": form_data.get("chat_id", None),
"title": True,
}
print(payload)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query")
if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
)
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
content = search_query_generation_template(
template, form_data["prompt"], user.model_dump()
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
"task": True,
}
print(payload)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/emoji/completions")
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
print("generate_emoji")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
Message: """{{prompt}}"""
'''
content = title_generation_template(
template, form_data["prompt"], user.model_dump()
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 4,
"chat_id": form_data.get("chat_id", None),
"task": True,
}
print(payload)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
print("get_tools_function_calling")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
context = await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user
)
return context
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**form_data), user=user
)
else:
return await generate_openai_chat_completion(form_data, user=user)
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
......@@ -591,6 +1090,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
}
@app.post("/api/pipelines/upload")
async def upload_pipeline(
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
):
print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file
if not file.filename.endswith(".py"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only Python (.py) files are allowed.",
)
upload_folder = f"{CACHE_DIR}/pipelines"
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
try:
# Save the uploaded file
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
with open(file_path, "rb") as f:
files = {"file": f}
r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)
finally:
# Ensure the file is deleted after the upload is completed or on failure
if os.path.exists(file_path):
os.remove(file_path)
class AddPipelineForm(BaseModel):
url: str
urlIdx: int
......@@ -857,6 +1413,15 @@ async def get_app_config():
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
},
"audio": {
"tts": {
"engine": audio_app.state.config.TTS_ENGINE,
"voice": audio_app.state.config.TTS_VOICE,
},
"stt": {
"engine": audio_app.state.config.STT_ENGINE,
},
},
"oauth": {
"providers": {
name: config.get("name", name)
......@@ -925,7 +1490,7 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
) as response:
......
......@@ -57,4 +57,8 @@ authlib==1.3.0
black==24.4.2
langfuse==2.33.0
youtube-transcript-api==0.6.2
pytube==15.0.0
\ No newline at end of file
pytube==15.0.0
extract_msg
pydub
duckduckgo-search~=6.1.5
\ No newline at end of file
......@@ -20,12 +20,12 @@ if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
WEBUI_SECRET_KEY=$(cat "$KEY_FILE")
fi
if [ "$USE_OLLAMA_DOCKER" = "true" ]; then
if [[ "${USE_OLLAMA_DOCKER,,}" == "true" ]]; then
echo "USE_OLLAMA is set to true, starting ollama serve."
ollama serve &
fi
if [ "$USE_CUDA_DOCKER" = "true" ]; then
if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then
echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi
......
......@@ -8,6 +8,7 @@ cd /d "%SCRIPT_DIR%" || exit /b
SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
......@@ -29,4 +30,4 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (
:: Execute uvicorn
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn main:app --host 0.0.0.0 --port "%PORT%" --forwarded-allow-ips '*'
uvicorn main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*'
......@@ -3,7 +3,48 @@ import hashlib
import json
import re
from datetime import timedelta
from typing import Optional
from typing import Optional, List
def get_last_user_message(messages: List[dict]) -> str:
for message in reversed(messages):
if message["role"] == "user":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return None
def get_last_assistant_message(messages: List[dict]) -> str:
for message in reversed(messages):
if message["role"] == "assistant":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return None
def add_or_update_system_message(content: str, messages: List[dict]):
"""
Adds a new system message at the beginning of the messages list
or updates the existing system message at the beginning.
:param msg: The message to be added or appended.
:param messages: The list of message dictionaries.
:return: The updated list of message dictionaries.
"""
if messages and messages[0].get("role") == "system":
messages[0]["content"] += f"{content}\n{messages[0]['content']}"
else:
# Insert at the beginning
messages.insert(0, {"role": "system", "content": content})
return messages
def get_gravatar_url(email):
......@@ -193,8 +234,14 @@ def parse_ollama_modelfile(model_text):
system_desc_match = re.search(
r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
)
system_desc_match_single = re.search(
r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
)
if system_desc_match:
data["params"]["system"] = system_desc_match.group(1).strip()
elif system_desc_match_single:
data["params"]["system"] = system_desc_match_single.group(1).strip()
# Parse messages
messages = []
......
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
def get_model_id_from_custom_model_id(id: str):
model = Models.get_model_by_id(id)
if model:
return model.id
else:
return id
import re
import math
from datetime import datetime
from typing import Optional
def prompt_template(
template: str, user_name: str = None, current_location: str = None
) -> str:
# Get the current date
current_date = datetime.now()
# Format the date to YYYY-MM-DD
formatted_date = current_date.strftime("%Y-%m-%d")
# Replace {{CURRENT_DATE}} in the template with the formatted date
template = template.replace("{{CURRENT_DATE}}", formatted_date)
if user_name:
# Replace {{USER_NAME}} in the template with the user's name
template = template.replace("{{USER_NAME}}", user_name)
if current_location:
# Replace {{CURRENT_LOCATION}} in the template with the current location
template = template.replace("{{CURRENT_LOCATION}}", current_location)
return template
def title_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
def replacement_function(match):
full_match = match.group(0)
start_length = match.group(1)
end_length = match.group(2)
middle_length = match.group(3)
if full_match == "{{prompt}}":
return prompt
elif start_length is not None:
return prompt[: int(start_length)]
elif end_length is not None:
return prompt[-int(end_length) :]
elif middle_length is not None:
middle_length = int(middle_length)
if len(prompt) <= middle_length:
return prompt
start = prompt[: math.ceil(middle_length / 2)]
end = prompt[-math.floor(middle_length / 2) :]
return f"{start}...{end}"
return ""
template = re.sub(
r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
replacement_function,
template,
)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "current_location": user.get("location")}
if user
else {}
),
)
return template
def search_query_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
def replacement_function(match):
full_match = match.group(0)
start_length = match.group(1)
end_length = match.group(2)
middle_length = match.group(3)
if full_match == "{{prompt}}":
return prompt
elif start_length is not None:
return prompt[: int(start_length)]
elif end_length is not None:
return prompt[-int(end_length) :]
elif middle_length is not None:
middle_length = int(middle_length)
if len(prompt) <= middle_length:
return prompt
start = prompt[: math.ceil(middle_length / 2)]
end = prompt[-math.floor(middle_length / 2) :]
return f"{start}...{end}"
return ""
template = re.sub(
r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
replacement_function,
template,
)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "current_location": user.get("location")}
if user
else {}
),
)
return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
return template
import inspect
from typing import get_type_hints, List, Dict, Any
def doc_to_dict(docstring):
lines = docstring.split("\n")
description = lines[1].strip()
param_dict = {}
for line in lines:
if ":param" in line:
line = line.replace(":param", "").strip()
param, desc = line.split(":", 1)
param_dict[param.strip()] = desc.strip()
ret_dict = {"description": description, "params": param_dict}
return ret_dict
def get_tools_specs(tools) -> List[dict]:
function_list = [
{"name": func, "function": getattr(tools, func)}
for func in dir(tools)
if callable(getattr(tools, func)) and not func.startswith("__")
]
specs = []
for function_item in function_list:
function_name = function_item["name"]
function = function_item["function"]
function_doc = doc_to_dict(function.__doc__ or function_name)
specs.append(
{
"name": function_name,
# TODO: multi-line desc?
"description": function_doc.get("description", function_name),
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_annotation.__name__.lower(),
**(
{
"enum": (
str(param_annotation.__args__)
if hasattr(param_annotation, "__args__")
else None
)
}
if hasattr(param_annotation, "__args__")
else {}
),
"description": function_doc.get("params", {}).get(
param_name, param_name
),
}
for param_name, param_annotation in get_type_hints(
function
).items()
if param_name != "return" and param_name != "__user__"
},
"required": [
name
for name, param in inspect.signature(
function
).parameters.items()
if param.default is param.empty
],
},
}
)
return specs
......@@ -28,19 +28,6 @@ describe('Settings', () => {
});
});
context('Connections', () => {
it('user can open the Connections modal and hit save', () => {
cy.get('button').contains('Connections').click();
cy.get('button').contains('Save').click();
});
});
context('Models', () => {
it('user can open the Models modal', () => {
cy.get('button').contains('Models').click();
});
});
context('Interface', () => {
it('user can open the Interface modal and hit save', () => {
cy.get('button').contains('Interface').click();
......@@ -55,14 +42,6 @@ describe('Settings', () => {
});
});
context('Images', () => {
it('user can open the Images modal and hit save', () => {
cy.get('button').contains('Images').click();
// Currently fails because the backend requires a valid URL
// cy.get('button').contains('Save').click();
});
});
context('Chats', () => {
it('user can open the Chats modal', () => {
cy.get('button').contains('Chats').click();
......
......@@ -41,7 +41,7 @@ Looking to contribute? Great! Here's how you can help:
We welcome pull requests. Before submitting one, please:
1. Discuss your idea or issue in the [issues section](https://github.com/open-webui/open-webui/issues).
1. Open a discussion regarding your ideas [here](https://github.com/open-webui/open-webui/discussions/new/choose).
2. Follow the project's coding standards and include tests for new features.
3. Update documentation as necessary.
4. Write clear, descriptive commit messages.
......
{
"name": "open-webui",
"version": "0.2.5",
"version": "0.3.4",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "open-webui",
"version": "0.2.5",
"version": "0.3.4",
"dependencies": {
"@codemirror/lang-javascript": "^6.2.2",
"@codemirror/lang-python": "^6.1.6",
"@codemirror/theme-one-dark": "^6.1.2",
"@pyscript/core": "^0.4.32",
"@sveltejs/adapter-node": "^1.3.1",
"async": "^3.2.5",
"bits-ui": "^0.19.7",
"codemirror": "^6.0.1",
"dayjs": "^1.11.10",
"eventsource-parser": "^1.1.2",
"file-saver": "^2.0.5",
......@@ -108,6 +112,119 @@
"resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-6.0.4.tgz",
"integrity": "sha512-s3jaWicZd0pkP0jf5ysyHUI/RE7MHos6qlToFcGWXVp+ykHOy77OUMrfbgJ9it2C5bow7OIQwYYaHjk9XlBQ2A=="
},
"node_modules/@codemirror/autocomplete": {
"version": "6.16.2",
"resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.16.2.tgz",
"integrity": "sha512-MjfDrHy0gHKlPWsvSsikhO1+BOh+eBHNgfH1OXs1+DAf30IonQldgMM3kxLDTG9ktE7kDLaA1j/l7KMPA4KNfw==",
"dependencies": {
"@codemirror/language": "^6.0.0",
"@codemirror/state": "^6.0.0",
"@codemirror/view": "^6.17.0",
"@lezer/common": "^1.0.0"
},
"peerDependencies": {
"@codemirror/language": "^6.0.0",
"@codemirror/state": "^6.0.0",
"@codemirror/view": "^6.0.0",
"@lezer/common": "^1.0.0"
}
},
"node_modules/@codemirror/commands": {
"version": "6.6.0",
"resolved": "https://registry.npmjs.org/@codemirror/commands/-/commands-6.6.0.tgz",
"integrity": "sha512-qnY+b7j1UNcTS31Eenuc/5YJB6gQOzkUoNmJQc0rznwqSRpeaWWpjkWy2C/MPTcePpsKJEM26hXrOXl1+nceXg==",
"dependencies": {
"@codemirror/language": "^6.0.0",
"@codemirror/state": "^6.4.0",
"@codemirror/view": "^6.27.0",
"@lezer/common": "^1.1.0"
}
},
"node_modules/@codemirror/lang-javascript": {
"version": "6.2.2",
"resolved": "https://registry.npmjs.org/@codemirror/lang-javascript/-/lang-javascript-6.2.2.tgz",
"integrity": "sha512-VGQfY+FCc285AhWuwjYxQyUQcYurWlxdKYT4bqwr3Twnd5wP5WSeu52t4tvvuWmljT4EmgEgZCqSieokhtY8hg==",
"dependencies": {
"@codemirror/autocomplete": "^6.0.0",
"@codemirror/language": "^6.6.0",
"@codemirror/lint": "^6.0.0",
"@codemirror/state": "^6.0.0",
"@codemirror/view": "^6.17.0",
"@lezer/common": "^1.0.0",
"@lezer/javascript": "^1.0.0"
}
},
"node_modules/@codemirror/lang-python": {
"version": "6.1.6",
"resolved": "https://registry.npmjs.org/@codemirror/lang-python/-/lang-python-6.1.6.tgz",
"integrity": "sha512-ai+01WfZhWqM92UqjnvorkxosZ2aq2u28kHvr+N3gu012XqY2CThD67JPMHnGceRfXPDBmn1HnyqowdpF57bNg==",
"dependencies": {
"@codemirror/autocomplete": "^6.3.2",
"@codemirror/language": "^6.8.0",
"@codemirror/state": "^6.0.0",
"@lezer/common": "^1.2.1",
"@lezer/python": "^1.1.4"
}
},
"node_modules/@codemirror/language": {
"version": "6.10.2",
"resolved": "https://registry.npmjs.org/@codemirror/language/-/language-6.10.2.tgz",
"integrity": "sha512-kgbTYTo0Au6dCSc/TFy7fK3fpJmgHDv1sG1KNQKJXVi+xBTEeBPY/M30YXiU6mMXeH+YIDLsbrT4ZwNRdtF+SA==",
"dependencies": {
"@codemirror/state": "^6.0.0",
"@codemirror/view": "^6.23.0",
"@lezer/common": "^1.1.0",
"@lezer/highlight": "^1.0.0",
"@lezer/lr": "^1.0.0",
"style-mod": "^4.0.0"
}
},
"node_modules/@codemirror/lint": {
"version": "6.8.0",
"resolved": "https://registry.npmjs.org/@codemirror/lint/-/lint-6.8.0.tgz",
"integrity": "sha512-lsFofvaw0lnPRJlQylNsC4IRt/1lI4OD/yYslrSGVndOJfStc58v+8p9dgGiD90ktOfL7OhBWns1ZETYgz0EJA==",
"dependencies": {
"@codemirror/state": "^6.0.0",
"@codemirror/view": "^6.0.0",
"crelt": "^1.0.5"
}
},
"node_modules/@codemirror/search": {
"version": "6.5.6",
"resolved": "https://registry.npmjs.org/@codemirror/search/-/search-6.5.6.tgz",
"integrity": "sha512-rpMgcsh7o0GuCDUXKPvww+muLA1pDJaFrpq/CCHtpQJYz8xopu4D1hPcKRoDD0YlF8gZaqTNIRa4VRBWyhyy7Q==",
"dependencies": {
"@codemirror/state": "^6.0.0",
"@codemirror/view": "^6.0.0",
"crelt": "^1.0.5"
}
},
"node_modules/@codemirror/state": {
"version": "6.4.1",
"resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.4.1.tgz",
"integrity": "sha512-QkEyUiLhsJoZkbumGZlswmAhA7CBU02Wrz7zvH4SrcifbsqwlXShVXg65f3v/ts57W3dqyamEriMhij1Z3Zz4A=="
},
"node_modules/@codemirror/theme-one-dark": {
"version": "6.1.2",
"resolved": "https://registry.npmjs.org/@codemirror/theme-one-dark/-/theme-one-dark-6.1.2.tgz",
"integrity": "sha512-F+sH0X16j/qFLMAfbciKTxVOwkdAS336b7AXTKOZhy8BR3eH/RelsnLgLFINrpST63mmN2OuwUt0W2ndUgYwUA==",
"dependencies": {
"@codemirror/language": "^6.0.0",
"@codemirror/state": "^6.0.0",
"@codemirror/view": "^6.0.0",
"@lezer/highlight": "^1.0.0"
}
},
"node_modules/@codemirror/view": {
"version": "6.28.0",
"resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.28.0.tgz",
"integrity": "sha512-fo7CelaUDKWIyemw4b+J57cWuRkOu4SWCCPfNDkPvfWkGjM9D5racHQXr4EQeYCD6zEBIBxGCeaKkQo+ysl0gA==",
"dependencies": {
"@codemirror/state": "^6.4.0",
"style-mod": "^4.1.0",
"w3c-keyname": "^2.2.4"
}
},
"node_modules/@colors/colors": {
"version": "1.5.0",
"resolved": "https://registry.npmjs.org/@colors/colors/-/colors-1.5.0.tgz",
......@@ -825,6 +942,47 @@
"@jridgewell/sourcemap-codec": "^1.4.14"
}
},
"node_modules/@lezer/common": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/@lezer/common/-/common-1.2.1.tgz",
"integrity": "sha512-yemX0ZD2xS/73llMZIK6KplkjIjf2EvAHcinDi/TfJ9hS25G0388+ClHt6/3but0oOxinTcQHJLDXh6w1crzFQ=="
},
"node_modules/@lezer/highlight": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/@lezer/highlight/-/highlight-1.2.0.tgz",
"integrity": "sha512-WrS5Mw51sGrpqjlh3d4/fOwpEV2Hd3YOkp9DBt4k8XZQcoTHZFB7sx030A6OcahF4J1nDQAa3jXlTVVYH50IFA==",
"dependencies": {
"@lezer/common": "^1.0.0"
}
},
"node_modules/@lezer/javascript": {
"version": "1.4.16",
"resolved": "https://registry.npmjs.org/@lezer/javascript/-/javascript-1.4.16.tgz",
"integrity": "sha512-84UXR3N7s11MPQHWgMnjb9571fr19MmXnr5zTv2XX0gHXXUvW3uPJ8GCjKrfTXmSdfktjRK0ayKklw+A13rk4g==",
"dependencies": {
"@lezer/common": "^1.2.0",
"@lezer/highlight": "^1.1.3",
"@lezer/lr": "^1.3.0"
}
},
"node_modules/@lezer/lr": {
"version": "1.4.1",
"resolved": "https://registry.npmjs.org/@lezer/lr/-/lr-1.4.1.tgz",
"integrity": "sha512-CHsKq8DMKBf9b3yXPDIU4DbH+ZJd/sJdYOW2llbW/HudP5u0VS6Bfq1hLYfgU7uAYGFIyGGQIsSOXGPEErZiJw==",
"dependencies": {
"@lezer/common": "^1.0.0"
}
},
"node_modules/@lezer/python": {
"version": "1.1.14",
"resolved": "https://registry.npmjs.org/@lezer/python/-/python-1.1.14.tgz",
"integrity": "sha512-ykDOb2Ti24n76PJsSa4ZoDF0zH12BSw1LGfQXCYJhJyOGiFTfGaX0Du66Ze72R+u/P35U+O6I9m8TFXov1JzsA==",
"dependencies": {
"@lezer/common": "^1.2.0",
"@lezer/highlight": "^1.0.0",
"@lezer/lr": "^1.0.0"
}
},
"node_modules/@melt-ui/svelte": {
"version": "0.76.0",
"resolved": "https://registry.npmjs.org/@melt-ui/svelte/-/svelte-0.76.0.tgz",
......@@ -2224,12 +2382,12 @@
}
},
"node_modules/braces": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz",
"integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==",
"version": "3.0.3",
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz",
"integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==",
"dev": true,
"dependencies": {
"fill-range": "^7.0.1"
"fill-range": "^7.1.1"
},
"engines": {
"node": ">=8"
......@@ -2769,6 +2927,20 @@
"plain-tag": "^0.1.3"
}
},
"node_modules/codemirror": {
"version": "6.0.1",
"resolved": "https://registry.npmjs.org/codemirror/-/codemirror-6.0.1.tgz",
"integrity": "sha512-J8j+nZ+CdWmIeFIGXEFbFPtpiYacFMDR8GlHK3IyHQJMCaVRfGx9NT+Hxivv1ckLWPvNdZqndbr/7lVhrf/Svg==",
"dependencies": {
"@codemirror/autocomplete": "^6.0.0",
"@codemirror/commands": "^6.0.0",
"@codemirror/language": "^6.0.0",
"@codemirror/lint": "^6.0.0",
"@codemirror/search": "^6.0.0",
"@codemirror/state": "^6.0.0",
"@codemirror/view": "^6.0.0"
}
},
"node_modules/coincident": {
"version": "1.2.3",
"resolved": "https://registry.npmjs.org/coincident/-/coincident-1.2.3.tgz",
......@@ -2891,6 +3063,11 @@
"layout-base": "^1.0.0"
}
},
"node_modules/crelt": {
"version": "1.0.6",
"resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz",
"integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g=="
},
"node_modules/cross-spawn": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
......@@ -4429,9 +4606,9 @@
"integrity": "sha512-P9bmyZ3h/PRG+Nzga+rbdI4OEpNDzAVyy74uVO9ATgzLK6VtAsYybF/+TOCvrc0MO793d6+42lLyZTw7/ArVzA=="
},
"node_modules/fill-range": {
"version": "7.0.1",
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz",
"integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==",
"version": "7.1.1",
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz",
"integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==",
"dev": true,
"dependencies": {
"to-regex-range": "^5.0.1"
......@@ -8278,6 +8455,11 @@
"url": "https://github.com/sponsors/antfu"
}
},
"node_modules/style-mod": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/style-mod/-/style-mod-4.1.2.tgz",
"integrity": "sha512-wnD1HyVqpJUI2+eKZ+eo1UwghftP6yuFheBqqe+bWCotBjC2K1YnteJILRMs3SM4V/0dLEW1SC27MWP5y+mwmw=="
},
"node_modules/stylis": {
"version": "4.3.2",
"resolved": "https://registry.npmjs.org/stylis/-/stylis-4.3.2.tgz",
......@@ -10022,6 +10204,11 @@
"he": "^1.2.0"
}
},
"node_modules/w3c-keyname": {
"version": "2.2.8",
"resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz",
"integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ=="
},
"node_modules/walk-sync": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/walk-sync/-/walk-sync-2.2.0.tgz",
......
{
"name": "open-webui",
"version": "0.2.5",
"version": "0.3.4",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",
......@@ -16,7 +16,7 @@
"format:backend": "black . --exclude \".venv/|/venv/\"",
"i18n:parse": "i18next --config i18next-parser.config.ts && prettier --write \"src/lib/i18n/**/*.{js,json}\"",
"cy:open": "cypress open",
"test:frontend": "vitest",
"test:frontend": "vitest --passWithNoTests",
"pyodide:fetch": "node scripts/prepare-pyodide.js"
},
"devDependencies": {
......@@ -48,10 +48,14 @@
},
"type": "module",
"dependencies": {
"@codemirror/lang-javascript": "^6.2.2",
"@codemirror/lang-python": "^6.1.6",
"@codemirror/theme-one-dark": "^6.1.2",
"@pyscript/core": "^0.4.32",
"@sveltejs/adapter-node": "^1.3.1",
"async": "^3.2.5",
"bits-ui": "^0.19.7",
"codemirror": "^6.0.1",
"dayjs": "^1.11.10",
"eventsource-parser": "^1.1.2",
"file-saver": "^2.0.5",
......
......@@ -26,8 +26,6 @@ dependencies = [
"PyMySQL==1.1.0",
"bcrypt==4.1.3",
"litellm[proxy]==1.37.20",
"boto3==1.34.110",
"argon2-cffi==23.1.0",
......@@ -67,6 +65,10 @@ dependencies = [
"langfuse==2.33.0",
"youtube-transcript-api==0.6.2",
"pytube==15.0.0",
"extract_msg",
"pydub",
"duckduckgo-search~=6.1.5"
]
readme = "README.md"
requires-python = ">= 3.11, < 3.12.0a1"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment