Unverified Commit 75d71305 authored by perf3ct's avatar perf3ct
Browse files

Merge remote-tracking branch 'upstream/main' into feature-external-db-reconnect

parents ad32a2ef 162643a4
import socketio
import asyncio
from apps.webui.models.users import Users
from utils.utils import decode_token
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
# Dictionary to maintain the user pool
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
# Timeout duration in seconds
TIMEOUT_DURATION = 3
@sio.event
async def connect(sid, environ, auth):
user = None
if auth and "token" in auth:
data = decode_token(auth["token"])
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("usage", {"models": get_models_in_use()})
@sio.on("user-join")
async def user_join(sid, data):
print("user-join", sid, data)
auth = data["auth"] if "auth" in data else None
if auth and "token" in auth:
data = decode_token(auth["token"])
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count")
async def user_count(sid):
await sio.emit("user-count", {"count": len(set(USER_POOL))})
def get_models_in_use():
# Aggregate all models in use
models_in_use = []
for model_id, data in USAGE_POOL.items():
models_in_use.append(model_id)
return models_in_use
@sio.on("usage")
async def usage(sid, data):
model_id = data["model"]
# Cancel previous callback if there is one
if model_id in USAGE_POOL:
USAGE_POOL[model_id]["callback"].cancel()
# Store the new usage data and task
if model_id in USAGE_POOL:
USAGE_POOL[model_id]["sids"].append(sid)
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
else:
USAGE_POOL[model_id] = {"sids": [sid]}
# Schedule a task to remove the usage data after TIMEOUT_DURATION
USAGE_POOL[model_id]["callback"] = asyncio.create_task(
remove_after_timeout(sid, model_id)
)
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
async def remove_after_timeout(sid, model_id):
try:
await asyncio.sleep(TIMEOUT_DURATION)
if model_id in USAGE_POOL:
print(USAGE_POOL[model_id]["sids"])
USAGE_POOL[model_id]["sids"].remove(sid)
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
if len(USAGE_POOL[model_id]["sids"]) == 0:
del USAGE_POOL[model_id]
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
except asyncio.CancelledError:
# Task was cancelled due to new 'usage' event
pass
@sio.event
async def disconnect(sid):
if sid in SESSION_POOL:
user_id = SESSION_POOL[sid]
del SESSION_POOL[sid]
USER_POOL[user_id].remove(sid)
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]
await sio.emit("user-count", {"count": len(USER_POOL)})
else:
print(f"Unknown session ID {sid} disconnected")
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Tool(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
content = pw.TextField()
specs = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")
......@@ -6,6 +6,7 @@ from apps.webui.routers import (
users,
chats,
documents,
tools,
models,
prompts,
configs,
......@@ -14,6 +15,8 @@ from apps.webui.routers import (
)
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
ADMIN_EMAIL,
WEBUI_AUTH,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
......@@ -24,8 +27,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
AppConfig,
ENABLE_COMMUNITY_SHARING,
AppConfig,
)
app = FastAPI()
......@@ -36,6 +39,12 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
......@@ -47,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.TOOLS = {}
app.add_middleware(
......@@ -63,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
......
......@@ -298,6 +298,15 @@ class ChatTable:
# .limit(limit).offset(skip)
]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
]
def delete_chat_by_id(self, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))
......
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class ToolMeta(BaseModel):
description: Optional[str] = None
class ToolModel(BaseModel):
id: str
user_id: str
name: str
content: str
specs: List[dict]
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool.create(**tool.model_dump())
if result:
return tool
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Tools = ToolsTable(DB)
......@@ -270,72 +270,87 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
############################
# ToggleSignUp
# GetAdminDetails
############################
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.ENABLE_SIGNUP
@router.get("/admin/details")
async def get_admin_details(request: Request, user=Depends(get_current_user)):
if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None
print(admin_email, admin_name)
if admin_email:
admin = Users.get_user_by_email(admin_email)
if admin:
admin_name = admin.name
else:
admin = Users.get_first_user()
if admin:
admin_email = admin.email
admin_name = admin.name
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
return request.app.state.config.ENABLE_SIGNUP
return {
"name": admin_name,
"email": admin_email,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
############################
# Default User Role
# ToggleSignUp
############################
@router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.DEFAULT_USER_ROLE
@router.get("/admin/config")
async def get_admin_config(request: Request, user=Depends(get_admin_user)):
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
}
class UpdateRoleForm(BaseModel):
role: str
class AdminConfig(BaseModel):
SHOW_ADMIN_DETAILS: bool
ENABLE_SIGNUP: bool
DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool
@router.post("/signup/user/role")
async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
@router.post("/admin/config")
async def update_admin_config(
request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
):
if form_data.role in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.role
return request.app.state.config.DEFAULT_USER_ROLE
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
############################
# JWT Expiration
############################
@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel):
duration: str
@router.post("/token/expires/update")
async def update_token_expires_duration(
request: Request,
form_data: UpdateJWTExpiresDurationForm,
user=Depends(get_admin_user),
):
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
# Check if the input string matches the pattern
if re.match(pattern, form_data.duration):
request.app.state.config.JWT_EXPIRES_IN = form_data.duration
return request.app.state.config.JWT_EXPIRES_IN
else:
return request.app.state.config.JWT_EXPIRES_IN
if re.match(pattern, form_data.JWT_EXPIRES_IN):
request.app.state.config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN
request.app.state.config.ENABLE_COMMUNITY_SHARING = (
form_data.ENABLE_COMMUNITY_SHARING
)
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
}
############################
......
......@@ -113,6 +113,19 @@ async def get_user_chats(user=Depends(get_current_user)):
]
############################
# GetArchivedChats
############################
@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id)
]
############################
# GetAllChatsInDB
############################
......@@ -148,7 +161,7 @@ async def get_archived_session_user_chat_list(
############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse])
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id)
......@@ -288,6 +301,32 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
return result
############################
# CloneChat
############################
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat_body = json.loads(chat.chat)
updated_chat = {
**chat_body,
"originalChatId": chat.id,
"branchPointMessageId": chat_body["history"]["currentId"],
"title": f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# ArchiveChat
############################
......
......@@ -73,7 +73,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
############################
@router.get("/name/{name}", response_model=Optional[DocumentResponse])
@router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
......@@ -105,7 +105,7 @@ class TagDocumentForm(BaseModel):
tags: List[dict]
@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
......@@ -128,7 +128,7 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
############################
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
):
......@@ -152,7 +152,7 @@ async def update_doc_by_name(
############################
@router.delete("/name/{name}/delete", response_model=bool)
@router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from config import DATA_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()
############################
# GetToolkits
############################
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# ExportToolKits
############################
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# CreateNewToolKit
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id)
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetToolkitById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolkitById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id)
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
specs = get_tools_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteToolkitById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
del TOOLS[id]
# delete the toolkit file
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
os.remove(toolkit_path)
return result
......@@ -19,7 +19,12 @@ from apps.webui.models.users import (
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats
from utils.utils import get_verified_user, get_password_hash, get_admin_user
from utils.utils import (
get_verified_user,
get_password_hash,
get_current_user,
get_admin_user,
)
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
......
......@@ -7,6 +7,8 @@ from pydantic import BaseModel
from fpdf import FPDF
import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
......@@ -26,6 +28,21 @@ async def get_gravatar(
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
class MarkdownForm(BaseModel):
md: str
......@@ -107,3 +124,12 @@ async def download_db(user=Depends(get_admin_user)):
media_type="application/octet-stream",
filename="webui.db",
)
@router.get("/litellm/config")
async def download_litellm_config_yaml(user=Depends(get_admin_user)):
return FileResponse(
f"{DATA_DIR}/litellm/config.yaml",
media_type="application/octet-stream",
filename="config.yaml",
)
from importlib import util
import os
from config import TOOLS_DIR
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {toolkit_id}")
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e
......@@ -180,6 +180,17 @@ WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
try:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except:
pass
try:
CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except:
......@@ -295,7 +306,11 @@ STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png"
if frontend_favicon.exists():
try:
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
......@@ -353,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# LITELLM_CONFIG
####################################
......@@ -590,6 +613,92 @@ WEBUI_BANNERS = PersistentConfig(
[BannerModel(**banner) for banner in json.loads("[]")],
)
SHOW_ADMIN_DETAILS = PersistentConfig(
"SHOW_ADMIN_DETAILS",
"auth.admin.show",
os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)
ADMIN_EMAIL = PersistentConfig(
"ADMIN_EMAIL",
"auth.admin.email",
os.environ.get("ADMIN_EMAIL", None),
)
####################################
# TASKS
####################################
TASK_MODEL = PersistentConfig(
"TASK_MODEL",
"task.model.default",
os.environ.get("TASK_MODEL", ""),
)
TASK_MODEL_EXTERNAL = PersistentConfig(
"TASK_MODEL_EXTERNAL",
"task.model.external",
os.environ.get("TASK_MODEL_EXTERNAL", ""),
)
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"task.title.prompt_template",
os.environ.get(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"""Here is the query:
{{prompt:middletruncate:8000}}
Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
),
)
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"task.search.prompt_template",
os.environ.get(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"""You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
Question:
{{prompt:end:4000}}""",
),
)
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold",
int(
os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
100,
)
),
)
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
os.environ.get(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"""Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
),
)
####################################
# WEBUI_SECRET_KEY
####################################
......@@ -672,6 +781,12 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
"RAG_EMBEDDING_OPENAI_BATCH_SIZE",
"rag.embedding_openai_batch_size",
os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", 1),
)
RAG_RERANKING_MODEL = PersistentConfig(
"RAG_RERANKING_MODEL",
"rag.reranking_model",
......@@ -766,28 +881,81 @@ YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
SEARXNG_QUERY_URL = os.getenv("SEARXNG_QUERY_URL", "")
GOOGLE_PSE_API_KEY = os.getenv("GOOGLE_PSE_API_KEY", "")
GOOGLE_PSE_ENGINE_ID = os.getenv("GOOGLE_PSE_ENGINE_ID", "")
BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY", "")
SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "")
SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true"
SERPER_API_KEY = os.getenv("SERPER_API_KEY", "")
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
"ENABLE_RAG_WEB_SEARCH",
"rag.web.search.enable",
os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
"RAG_WEB_SEARCH_ENGINE",
"rag.web.search.engine",
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url",
os.getenv("SEARXNG_QUERY_URL", ""),
)
GOOGLE_PSE_API_KEY = PersistentConfig(
"GOOGLE_PSE_API_KEY",
"rag.web.search.google_pse_api_key",
os.getenv("GOOGLE_PSE_API_KEY", ""),
)
GOOGLE_PSE_ENGINE_ID = PersistentConfig(
"GOOGLE_PSE_ENGINE_ID",
"rag.web.search.google_pse_engine_id",
os.getenv("GOOGLE_PSE_ENGINE_ID", ""),
)
BRAVE_SEARCH_API_KEY = PersistentConfig(
"BRAVE_SEARCH_API_KEY",
"rag.web.search.brave_search_api_key",
os.getenv("BRAVE_SEARCH_API_KEY", ""),
)
SERPSTACK_API_KEY = PersistentConfig(
"SERPSTACK_API_KEY",
"rag.web.search.serpstack_api_key",
os.getenv("SERPSTACK_API_KEY", ""),
)
SERPSTACK_HTTPS = PersistentConfig(
"SERPSTACK_HTTPS",
"rag.web.search.serpstack_https",
os.getenv("SERPSTACK_HTTPS", "True").lower() == "true",
)
SERPER_API_KEY = PersistentConfig(
"SERPER_API_KEY",
"rag.web.search.serper_api_key",
os.getenv("SERPER_API_KEY", ""),
)
SERPLY_API_KEY = PersistentConfig(
"SERPLY_API_KEY",
"rag.web.search.serply_api_key",
os.getenv("SERPLY_API_KEY", ""),
)
RAG_WEB_SEARCH_ENABLED = (
SEARXNG_QUERY_URL != ""
or (GOOGLE_PSE_API_KEY != "" and GOOGLE_PSE_ENGINE_ID != "")
or BRAVE_SEARCH_API_KEY != ""
or SERPSTACK_API_KEY != ""
or SERPER_API_KEY != ""
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",
"rag.web.search.result_count",
int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)
RAG_WEB_SEARCH_RESULT_COUNT = int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3"))
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = int(
os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
"RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
"rag.web.search.concurrent_requests",
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
####################################
# Transcribe
####################################
......@@ -855,25 +1023,59 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_OPENAI_API_BASE_URL",
"audio.openai.api_base_url",
os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
"audio.stt.openai.api_base_url",
os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_OPENAI_API_KEY = PersistentConfig(
"AUDIO_OPENAI_API_KEY",
"audio.openai.api_key",
os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
"AUDIO_STT_OPENAI_API_KEY",
"audio.stt.openai.api_key",
os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_OPENAI_API_MODEL = PersistentConfig(
"AUDIO_OPENAI_API_MODEL",
"audio.openai.api_model",
os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
AUDIO_STT_ENGINE = PersistentConfig(
"AUDIO_STT_ENGINE",
"audio.stt.engine",
os.getenv("AUDIO_STT_ENGINE", ""),
)
AUDIO_STT_MODEL = PersistentConfig(
"AUDIO_STT_MODEL",
"audio.stt.model",
os.getenv("AUDIO_STT_MODEL", "whisper-1"),
)
AUDIO_OPENAI_API_VOICE = PersistentConfig(
"AUDIO_OPENAI_API_VOICE",
"audio.openai.api_voice",
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_TTS_OPENAI_API_BASE_URL",
"audio.tts.openai.api_base_url",
os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
"AUDIO_TTS_OPENAI_API_KEY",
"audio.tts.openai.api_key",
os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_TTS_ENGINE = PersistentConfig(
"AUDIO_TTS_ENGINE",
"audio.tts.engine",
os.getenv("AUDIO_TTS_ENGINE", ""),
)
AUDIO_TTS_MODEL = PersistentConfig(
"AUDIO_TTS_MODEL",
"audio.tts.model",
os.getenv("AUDIO_TTS_MODEL", "tts-1"),
)
AUDIO_TTS_VOICE = PersistentConfig(
"AUDIO_TTS_VOICE",
"audio.tts.voice",
os.getenv("AUDIO_TTS_VOICE", "alloy"),
)
......
......@@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum):
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
......@@ -82,5 +83,9 @@ class ERROR_MESSAGES(str, Enum):
)
WEB_SEARCH_ERROR = (
"Oops! Something went wrong while searching the web. Please try again later."
lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
)
OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature."
)
......@@ -9,8 +9,12 @@ import logging
import aiohttp
import requests
import mimetypes
import shutil
import os
import inspect
import asyncio
from fastapi import FastAPI, Request, Depends, status
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException
......@@ -20,26 +24,48 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
from apps.socket.main import app as socket_app
from apps.ollama.main import (
app as ollama_app,
OpenAIChatCompletionForm,
get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion,
)
from apps.openai.main import (
app as openai_app,
get_all_models as get_openai_models,
generate_chat_completion as generate_openai_chat_completion,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app
import asyncio
from pydantic import BaseModel
from typing import List, Optional
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import (
get_admin_user,
get_verified_user,
get_current_user,
get_http_authorization_cred,
)
from apps.rag.utils import rag_messages
from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from apps.rag.utils import get_rag_context, rag_template
from config import (
CONFIG_DATA,
......@@ -60,9 +86,14 @@ from config import (
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
RAG_WEB_SEARCH_ENABLED,
AppConfig,
WEBUI_BUILD_HASH,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
AppConfig,
)
from constants import ERROR_MESSAGES
......@@ -116,27 +147,133 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.MODELS = {}
origins = ["*"]
# Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next):
# response: Response = await call_next(request)
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
# return response
async def get_function_call_response(messages, tool_id, template, task_model_id, user):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
user_message = get_last_user_message(messages)
prompt = (
"History:\n"
+ "\n".join(
[
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
]
)
+ f"\nQuery: {user_message}"
)
print(prompt)
payload = {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
}
try:
payload = filter_pipeline(payload, user)
except Exception as e:
raise e
model = app.state.MODELS[task_model_id]
response = None
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# app.add_middleware(SecurityHeadersMiddleware)
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
# Parse the function response
if content is not None:
print(f"content: {content}")
result = json.loads(content)
print(result)
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
# Check if '__user__' is a parameter of the function
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
function_result = function(
**{
**result["parameters"],
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
}
)
else:
# Call the function without modifying the parameters
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
except Exception as e:
print(f"Error: {e}")
return None
class RAGMiddleware(BaseHTTPMiddleware):
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
......@@ -153,35 +290,98 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
# Remove the citations from the body
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
# Set the task model
task_model_id = data["model"]
if task_model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
try:
response = await get_function_call_response(
messages=data["messages"],
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
if response:
context += ("\n" if context != "" else "") + response
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
# If docs field is present, generate RAG completions
if "docs" in data:
data = {**data}
data["messages"], citations = rag_messages(
rag_context, citations = get_rag_context(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.config.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
del data["docs"]
log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
......@@ -224,25 +424,12 @@ class RAGMiddleware(BaseHTTPMiddleware):
yield data
app.add_middleware(RAGMiddleware)
app.add_middleware(ChatCompletionMiddleware)
class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and (
"/ollama/api/chat" in request.url.path
or "/chat/completions" in request.url.path
):
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
model_id = data["model"]
def filter_pipeline(payload, user):
user = {"id": user.id, "name": user.name, "role": user.role}
model_id = payload["model"]
filters = [
model
for model in app.state.MODELS.values()
......@@ -259,17 +446,10 @@ class PipelineMiddleware(BaseHTTPMiddleware):
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
user = None
if len(sorted_filters) > 0:
try:
user = get_current_user(
get_http_authorization_cred(
request.headers.get("Authorization")
)
)
user = {"id": user.id, "name": user.name, "role": user.role}
except:
pass
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters.append(model)
for filter in sorted_filters:
r = None
......@@ -286,12 +466,12 @@ class PipelineMiddleware(BaseHTTPMiddleware):
headers=headers,
json={
"user": user,
"body": data,
"body": payload,
},
)
r.raise_for_status()
data = r.json()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
......@@ -299,19 +479,50 @@ class PipelineMiddleware(BaseHTTPMiddleware):
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
if "detail" in res:
raise Exception(r.status_code, res["detail"])
else:
pass
if "chat_id" in data:
del data["chat_id"]
if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]
return payload
class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and (
"/ollama/api/chat" in request.url.path
or "/chat/completions" in request.url.path
):
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
try:
data = filter_pipeline(data, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
......@@ -368,6 +579,9 @@ async def update_embedding_function(request: Request, call_next):
return response
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
......@@ -469,6 +683,237 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models}
@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@app.post("/api/task/config/update")
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
app.state.config.TASK_MODEL = form_data.TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@app.post("/api/task/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
print("generate_title")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
content = title_generation_template(
template, form_data["prompt"], user.model_dump()
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 50,
"chat_id": form_data.get("chat_id", None),
"title": True,
}
print(payload)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query")
if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
)
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
content = search_query_generation_template(
template, form_data["prompt"], user.model_dump()
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
}
print(payload)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
print("get_tools_function_calling")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
context = await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user
)
return context
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**form_data), user=user
)
else:
return await generate_openai_chat_completion(form_data, user=user)
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
......@@ -490,6 +935,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
print(model_id)
if model_id in app.state.MODELS:
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
......@@ -537,7 +989,11 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
responses = await get_openai_models(raw=True)
print(responses)
urlIdxs = [idx for idx, response in enumerate(responses) if "pipelines" in response]
urlIdxs = [
idx
for idx, response in enumerate(responses)
if response != None and "pipelines" in response
]
return {
"data": [
......@@ -550,6 +1006,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
}
@app.post("/api/pipelines/upload")
async def upload_pipeline(
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
):
print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file
if not file.filename.endswith(".py"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only Python (.py) files are allowed.",
)
upload_folder = f"{CACHE_DIR}/pipelines"
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
try:
# Save the uploaded file
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
with open(file_path, "rb") as f:
files = {"file": f}
r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)
finally:
# Ensure the file is deleted after the upload is completed or on failure
if os.path.exists(file_path):
os.remove(file_path)
class AddPipelineForm(BaseModel):
url: str
urlIdx: int
......@@ -811,11 +1324,20 @@ async def get_app_config():
"auth": WEBUI_AUTH,
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_signup": webui_app.state.config.ENABLE_SIGNUP,
"enable_web_search": RAG_WEB_SEARCH_ENABLED,
"enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
},
"audio": {
"tts": {
"engine": audio_app.state.config.TTS_ENGINE,
"voice": audio_app.state.config.TTS_VOICE,
},
"stt": {
"engine": audio_app.state.config.STT_ENGINE,
},
},
}
......@@ -860,23 +1382,7 @@ class UrlForm(BaseModel):
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return {
"url": app.state.config.WEBHOOK_URL,
}
@app.get("/api/community_sharing", response_model=bool)
async def get_community_sharing_status(request: Request, user=Depends(get_admin_user)):
return webui_app.state.config.ENABLE_COMMUNITY_SHARING
@app.get("/api/community_sharing/toggle", response_model=bool)
async def toggle_community_sharing(request: Request, user=Depends(get_admin_user)):
webui_app.state.config.ENABLE_COMMUNITY_SHARING = (
not webui_app.state.config.ENABLE_COMMUNITY_SHARING
)
return webui_app.state.config.ENABLE_COMMUNITY_SHARING
return {"url": app.state.config.WEBHOOK_URL}
@app.get("/api/version")
......@@ -894,7 +1400,7 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
) as response:
......
......@@ -57,3 +57,7 @@ black==24.4.2
langfuse==2.33.0
youtube-transcript-api==0.6.2
pytube==15.0.0
extract_msg
pydub
duckduckgo-search~=6.1.5
\ No newline at end of file
......@@ -20,12 +20,12 @@ if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
WEBUI_SECRET_KEY=$(cat "$KEY_FILE")
fi
if [ "$USE_OLLAMA_DOCKER" = "true" ]; then
if [[ "${USE_OLLAMA_DOCKER,,}" == "true" ]]; then
echo "USE_OLLAMA is set to true, starting ollama serve."
ollama serve &
fi
if [ "$USE_CUDA_DOCKER" = "true" ]; then
if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then
echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi
......
......@@ -8,6 +8,7 @@ cd /d "%SCRIPT_DIR%" || exit /b
SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
......@@ -29,4 +30,4 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (
:: Execute uvicorn
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn main:app --host 0.0.0.0 --port "%PORT%" --forwarded-allow-ips '*'
uvicorn main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*'
......@@ -3,7 +3,48 @@ import hashlib
import json
import re
from datetime import timedelta
from typing import Optional
from typing import Optional, List
def get_last_user_message(messages: List[dict]) -> str:
for message in reversed(messages):
if message["role"] == "user":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return None
def get_last_assistant_message(messages: List[dict]) -> str:
for message in reversed(messages):
if message["role"] == "assistant":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return None
def add_or_update_system_message(content: str, messages: List[dict]):
"""
Adds a new system message at the beginning of the messages list
or updates the existing system message at the beginning.
:param msg: The message to be added or appended.
:param messages: The list of message dictionaries.
:return: The updated list of message dictionaries.
"""
if messages and messages[0].get("role") == "system":
messages[0]["content"] += f"{content}\n{messages[0]['content']}"
else:
# Insert at the beginning
messages.insert(0, {"role": "system", "content": content})
return messages
def get_gravatar_url(email):
......@@ -123,11 +164,25 @@ def parse_ollama_modelfile(model_text):
"repeat_penalty": float,
"temperature": float,
"seed": int,
"stop": str,
"tfs_z": float,
"num_predict": int,
"top_k": int,
"top_p": float,
"num_keep": int,
"typical_p": float,
"presence_penalty": float,
"frequency_penalty": float,
"penalize_newline": bool,
"numa": bool,
"num_batch": int,
"num_gpu": int,
"main_gpu": int,
"low_vram": bool,
"f16_kv": bool,
"vocab_only": bool,
"use_mmap": bool,
"use_mlock": bool,
"num_thread": int,
}
data = {"base_model_id": None, "params": {}}
......@@ -156,10 +211,18 @@ def parse_ollama_modelfile(model_text):
param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
if param_match:
value = param_match.group(1)
try:
if param_type == int:
value = int(value)
elif param_type == float:
value = float(value)
elif param_type == bool:
value = value.lower() == "true"
except Exception as e:
print(e)
continue
data["params"][param] = value
# Parse adapter
......@@ -171,8 +234,14 @@ def parse_ollama_modelfile(model_text):
system_desc_match = re.search(
r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
)
system_desc_match_single = re.search(
r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
)
if system_desc_match:
data["params"]["system"] = system_desc_match.group(1).strip()
elif system_desc_match_single:
data["params"]["system"] = system_desc_match_single.group(1).strip()
# Parse messages
messages = []
......
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
def get_model_id_from_custom_model_id(id: str):
model = Models.get_model_by_id(id)
if model:
return model.id
else:
return id
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment