Unverified Commit 75d71305 authored by perf3ct's avatar perf3ct
Browse files

Merge remote-tracking branch 'upstream/main' into feature-external-db-reconnect

parents ad32a2ef 162643a4
import socketio
import asyncio
from apps.webui.models.users import Users
from utils.utils import decode_token
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
# Dictionary to maintain the user pool
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
# Timeout duration in seconds
TIMEOUT_DURATION = 3
@sio.event
async def connect(sid, environ, auth):
user = None
if auth and "token" in auth:
data = decode_token(auth["token"])
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("usage", {"models": get_models_in_use()})
@sio.on("user-join")
async def user_join(sid, data):
print("user-join", sid, data)
auth = data["auth"] if "auth" in data else None
if auth and "token" in auth:
data = decode_token(auth["token"])
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count")
async def user_count(sid):
await sio.emit("user-count", {"count": len(set(USER_POOL))})
def get_models_in_use():
# Aggregate all models in use
models_in_use = []
for model_id, data in USAGE_POOL.items():
models_in_use.append(model_id)
return models_in_use
@sio.on("usage")
async def usage(sid, data):
model_id = data["model"]
# Cancel previous callback if there is one
if model_id in USAGE_POOL:
USAGE_POOL[model_id]["callback"].cancel()
# Store the new usage data and task
if model_id in USAGE_POOL:
USAGE_POOL[model_id]["sids"].append(sid)
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
else:
USAGE_POOL[model_id] = {"sids": [sid]}
# Schedule a task to remove the usage data after TIMEOUT_DURATION
USAGE_POOL[model_id]["callback"] = asyncio.create_task(
remove_after_timeout(sid, model_id)
)
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
async def remove_after_timeout(sid, model_id):
try:
await asyncio.sleep(TIMEOUT_DURATION)
if model_id in USAGE_POOL:
print(USAGE_POOL[model_id]["sids"])
USAGE_POOL[model_id]["sids"].remove(sid)
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
if len(USAGE_POOL[model_id]["sids"]) == 0:
del USAGE_POOL[model_id]
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
except asyncio.CancelledError:
# Task was cancelled due to new 'usage' event
pass
@sio.event
async def disconnect(sid):
if sid in SESSION_POOL:
user_id = SESSION_POOL[sid]
del SESSION_POOL[sid]
USER_POOL[user_id].remove(sid)
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]
await sio.emit("user-count", {"count": len(USER_POOL)})
else:
print(f"Unknown session ID {sid} disconnected")
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Tool(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
content = pw.TextField()
specs = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")
......@@ -6,6 +6,7 @@ from apps.webui.routers import (
users,
chats,
documents,
tools,
models,
prompts,
configs,
......@@ -14,6 +15,8 @@ from apps.webui.routers import (
)
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
ADMIN_EMAIL,
WEBUI_AUTH,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
......@@ -24,8 +27,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
AppConfig,
ENABLE_COMMUNITY_SHARING,
AppConfig,
)
app = FastAPI()
......@@ -36,6 +39,12 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
......@@ -47,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.TOOLS = {}
app.add_middleware(
......@@ -63,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
......
......@@ -298,6 +298,15 @@ class ChatTable:
# .limit(limit).offset(skip)
]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
]
def delete_chat_by_id(self, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))
......
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class ToolMeta(BaseModel):
description: Optional[str] = None
class ToolModel(BaseModel):
id: str
user_id: str
name: str
content: str
specs: List[dict]
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool.create(**tool.model_dump())
if result:
return tool
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Tools = ToolsTable(DB)
......@@ -270,72 +270,87 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
############################
# ToggleSignUp
# GetAdminDetails
############################
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.ENABLE_SIGNUP
@router.get("/admin/details")
async def get_admin_details(request: Request, user=Depends(get_current_user)):
if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None
print(admin_email, admin_name)
if admin_email:
admin = Users.get_user_by_email(admin_email)
if admin:
admin_name = admin.name
else:
admin = Users.get_first_user()
if admin:
admin_email = admin.email
admin_name = admin.name
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
return request.app.state.config.ENABLE_SIGNUP
return {
"name": admin_name,
"email": admin_email,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
############################
# Default User Role
# ToggleSignUp
############################
@router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.DEFAULT_USER_ROLE
@router.get("/admin/config")
async def get_admin_config(request: Request, user=Depends(get_admin_user)):
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
}
class UpdateRoleForm(BaseModel):
role: str
class AdminConfig(BaseModel):
SHOW_ADMIN_DETAILS: bool
ENABLE_SIGNUP: bool
DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool
@router.post("/signup/user/role")
async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
@router.post("/admin/config")
async def update_admin_config(
request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
):
if form_data.role in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.role
return request.app.state.config.DEFAULT_USER_ROLE
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
############################
# JWT Expiration
############################
@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel):
duration: str
@router.post("/token/expires/update")
async def update_token_expires_duration(
request: Request,
form_data: UpdateJWTExpiresDurationForm,
user=Depends(get_admin_user),
):
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
# Check if the input string matches the pattern
if re.match(pattern, form_data.duration):
request.app.state.config.JWT_EXPIRES_IN = form_data.duration
return request.app.state.config.JWT_EXPIRES_IN
else:
return request.app.state.config.JWT_EXPIRES_IN
if re.match(pattern, form_data.JWT_EXPIRES_IN):
request.app.state.config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN
request.app.state.config.ENABLE_COMMUNITY_SHARING = (
form_data.ENABLE_COMMUNITY_SHARING
)
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
}
############################
......
......@@ -113,6 +113,19 @@ async def get_user_chats(user=Depends(get_current_user)):
]
############################
# GetArchivedChats
############################
@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id)
]
############################
# GetAllChatsInDB
############################
......@@ -148,7 +161,7 @@ async def get_archived_session_user_chat_list(
############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse])
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id)
......@@ -288,6 +301,32 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
return result
############################
# CloneChat
############################
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat_body = json.loads(chat.chat)
updated_chat = {
**chat_body,
"originalChatId": chat.id,
"branchPointMessageId": chat_body["history"]["currentId"],
"title": f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# ArchiveChat
############################
......
......@@ -73,7 +73,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
############################
@router.get("/name/{name}", response_model=Optional[DocumentResponse])
@router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
......@@ -105,7 +105,7 @@ class TagDocumentForm(BaseModel):
tags: List[dict]
@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
......@@ -128,7 +128,7 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
############################
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
):
......@@ -152,7 +152,7 @@ async def update_doc_by_name(
############################
@router.delete("/name/{name}/delete", response_model=bool)
@router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from config import DATA_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()
############################
# GetToolkits
############################
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# ExportToolKits
############################
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# CreateNewToolKit
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id)
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetToolkitById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolkitById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id)
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
specs = get_tools_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteToolkitById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
del TOOLS[id]
# delete the toolkit file
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
os.remove(toolkit_path)
return result
......@@ -19,7 +19,12 @@ from apps.webui.models.users import (
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats
from utils.utils import get_verified_user, get_password_hash, get_admin_user
from utils.utils import (
get_verified_user,
get_password_hash,
get_current_user,
get_admin_user,
)
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
......
......@@ -7,6 +7,8 @@ from pydantic import BaseModel
from fpdf import FPDF
import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
......@@ -26,6 +28,21 @@ async def get_gravatar(
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
class MarkdownForm(BaseModel):
md: str
......@@ -107,3 +124,12 @@ async def download_db(user=Depends(get_admin_user)):
media_type="application/octet-stream",
filename="webui.db",
)
@router.get("/litellm/config")
async def download_litellm_config_yaml(user=Depends(get_admin_user)):
return FileResponse(
f"{DATA_DIR}/litellm/config.yaml",
media_type="application/octet-stream",
filename="config.yaml",
)
from importlib import util
import os
from config import TOOLS_DIR
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {toolkit_id}")
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e
......@@ -180,6 +180,17 @@ WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
try:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except:
pass
try:
CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except:
......@@ -295,7 +306,11 @@ STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png"
if frontend_favicon.exists():
try:
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
......@@ -353,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# LITELLM_CONFIG
####################################
......@@ -590,6 +613,92 @@ WEBUI_BANNERS = PersistentConfig(
[BannerModel(**banner) for banner in json.loads("[]")],
)
SHOW_ADMIN_DETAILS = PersistentConfig(
"SHOW_ADMIN_DETAILS",
"auth.admin.show",
os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)
ADMIN_EMAIL = PersistentConfig(
"ADMIN_EMAIL",
"auth.admin.email",
os.environ.get("ADMIN_EMAIL", None),
)
####################################
# TASKS
####################################
TASK_MODEL = PersistentConfig(
"TASK_MODEL",
"task.model.default",
os.environ.get("TASK_MODEL", ""),
)
TASK_MODEL_EXTERNAL = PersistentConfig(
"TASK_MODEL_EXTERNAL",
"task.model.external",
os.environ.get("TASK_MODEL_EXTERNAL", ""),
)
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"task.title.prompt_template",
os.environ.get(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"""Here is the query:
{{prompt:middletruncate:8000}}
Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
),
)
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"task.search.prompt_template",
os.environ.get(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"""You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
Question:
{{prompt:end:4000}}""",
),
)
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold",
int(
os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
100,
)
),
)
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
os.environ.get(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"""Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
),
)
####################################
# WEBUI_SECRET_KEY
####################################
......@@ -672,6 +781,12 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
"RAG_EMBEDDING_OPENAI_BATCH_SIZE",
"rag.embedding_openai_batch_size",
os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", 1),
)
RAG_RERANKING_MODEL = PersistentConfig(
"RAG_RERANKING_MODEL",
"rag.reranking_model",
......@@ -766,28 +881,81 @@ YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
SEARXNG_QUERY_URL = os.getenv("SEARXNG_QUERY_URL", "")
GOOGLE_PSE_API_KEY = os.getenv("GOOGLE_PSE_API_KEY", "")
GOOGLE_PSE_ENGINE_ID = os.getenv("GOOGLE_PSE_ENGINE_ID", "")
BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY", "")
SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "")
SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true"
SERPER_API_KEY = os.getenv("SERPER_API_KEY", "")
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
"ENABLE_RAG_WEB_SEARCH",
"rag.web.search.enable",
os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
"RAG_WEB_SEARCH_ENGINE",
"rag.web.search.engine",
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url",
os.getenv("SEARXNG_QUERY_URL", ""),
)
GOOGLE_PSE_API_KEY = PersistentConfig(
"GOOGLE_PSE_API_KEY",
"rag.web.search.google_pse_api_key",
os.getenv("GOOGLE_PSE_API_KEY", ""),
)
GOOGLE_PSE_ENGINE_ID = PersistentConfig(
"GOOGLE_PSE_ENGINE_ID",
"rag.web.search.google_pse_engine_id",
os.getenv("GOOGLE_PSE_ENGINE_ID", ""),
)
BRAVE_SEARCH_API_KEY = PersistentConfig(
"BRAVE_SEARCH_API_KEY",
"rag.web.search.brave_search_api_key",
os.getenv("BRAVE_SEARCH_API_KEY", ""),
)
SERPSTACK_API_KEY = PersistentConfig(
"SERPSTACK_API_KEY",
"rag.web.search.serpstack_api_key",
os.getenv("SERPSTACK_API_KEY", ""),
)
SERPSTACK_HTTPS = PersistentConfig(
"SERPSTACK_HTTPS",
"rag.web.search.serpstack_https",
os.getenv("SERPSTACK_HTTPS", "True").lower() == "true",
)
SERPER_API_KEY = PersistentConfig(
"SERPER_API_KEY",
"rag.web.search.serper_api_key",
os.getenv("SERPER_API_KEY", ""),
)
SERPLY_API_KEY = PersistentConfig(
"SERPLY_API_KEY",
"rag.web.search.serply_api_key",
os.getenv("SERPLY_API_KEY", ""),
)
RAG_WEB_SEARCH_ENABLED = (
SEARXNG_QUERY_URL != ""
or (GOOGLE_PSE_API_KEY != "" and GOOGLE_PSE_ENGINE_ID != "")
or BRAVE_SEARCH_API_KEY != ""
or SERPSTACK_API_KEY != ""
or SERPER_API_KEY != ""
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",
"rag.web.search.result_count",
int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)
RAG_WEB_SEARCH_RESULT_COUNT = int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3"))
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = int(
os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
"RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
"rag.web.search.concurrent_requests",
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
####################################
# Transcribe
####################################
......@@ -855,25 +1023,59 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_OPENAI_API_BASE_URL",
"audio.openai.api_base_url",
os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
"audio.stt.openai.api_base_url",
os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_OPENAI_API_KEY = PersistentConfig(
"AUDIO_OPENAI_API_KEY",
"audio.openai.api_key",
os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
"AUDIO_STT_OPENAI_API_KEY",
"audio.stt.openai.api_key",
os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_OPENAI_API_MODEL = PersistentConfig(
"AUDIO_OPENAI_API_MODEL",
"audio.openai.api_model",
os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
AUDIO_STT_ENGINE = PersistentConfig(
"AUDIO_STT_ENGINE",
"audio.stt.engine",
os.getenv("AUDIO_STT_ENGINE", ""),
)
AUDIO_STT_MODEL = PersistentConfig(
"AUDIO_STT_MODEL",
"audio.stt.model",
os.getenv("AUDIO_STT_MODEL", "whisper-1"),
)
AUDIO_OPENAI_API_VOICE = PersistentConfig(
"AUDIO_OPENAI_API_VOICE",
"audio.openai.api_voice",
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_TTS_OPENAI_API_BASE_URL",
"audio.tts.openai.api_base_url",
os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
"AUDIO_TTS_OPENAI_API_KEY",
"audio.tts.openai.api_key",
os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_TTS_ENGINE = PersistentConfig(
"AUDIO_TTS_ENGINE",
"audio.tts.engine",
os.getenv("AUDIO_TTS_ENGINE", ""),
)
AUDIO_TTS_MODEL = PersistentConfig(
"AUDIO_TTS_MODEL",
"audio.tts.model",
os.getenv("AUDIO_TTS_MODEL", "tts-1"),
)
AUDIO_TTS_VOICE = PersistentConfig(
"AUDIO_TTS_VOICE",
"audio.tts.voice",
os.getenv("AUDIO_TTS_VOICE", "alloy"),
)
......
......@@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum):
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
......@@ -82,5 +83,9 @@ class ERROR_MESSAGES(str, Enum):
)
WEB_SEARCH_ERROR = (
"Oops! Something went wrong while searching the web. Please try again later."
lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
)
OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature."
)
This diff is collapsed.
......@@ -57,3 +57,7 @@ black==24.4.2
langfuse==2.33.0
youtube-transcript-api==0.6.2
pytube==15.0.0
extract_msg
pydub
duckduckgo-search~=6.1.5
\ No newline at end of file
......@@ -20,12 +20,12 @@ if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
WEBUI_SECRET_KEY=$(cat "$KEY_FILE")
fi
if [ "$USE_OLLAMA_DOCKER" = "true" ]; then
if [[ "${USE_OLLAMA_DOCKER,,}" == "true" ]]; then
echo "USE_OLLAMA is set to true, starting ollama serve."
ollama serve &
fi
if [ "$USE_CUDA_DOCKER" = "true" ]; then
if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then
echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi
......
......@@ -8,6 +8,7 @@ cd /d "%SCRIPT_DIR%" || exit /b
SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
......@@ -29,4 +30,4 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (
:: Execute uvicorn
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn main:app --host 0.0.0.0 --port "%PORT%" --forwarded-allow-ips '*'
uvicorn main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*'
......@@ -3,7 +3,48 @@ import hashlib
import json
import re
from datetime import timedelta
from typing import Optional
from typing import Optional, List
def get_last_user_message(messages: List[dict]) -> str:
for message in reversed(messages):
if message["role"] == "user":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return None
def get_last_assistant_message(messages: List[dict]) -> str:
for message in reversed(messages):
if message["role"] == "assistant":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return None
def add_or_update_system_message(content: str, messages: List[dict]):
"""
Adds a new system message at the beginning of the messages list
or updates the existing system message at the beginning.
:param msg: The message to be added or appended.
:param messages: The list of message dictionaries.
:return: The updated list of message dictionaries.
"""
if messages and messages[0].get("role") == "system":
messages[0]["content"] += f"{content}\n{messages[0]['content']}"
else:
# Insert at the beginning
messages.insert(0, {"role": "system", "content": content})
return messages
def get_gravatar_url(email):
......@@ -123,11 +164,25 @@ def parse_ollama_modelfile(model_text):
"repeat_penalty": float,
"temperature": float,
"seed": int,
"stop": str,
"tfs_z": float,
"num_predict": int,
"top_k": int,
"top_p": float,
"num_keep": int,
"typical_p": float,
"presence_penalty": float,
"frequency_penalty": float,
"penalize_newline": bool,
"numa": bool,
"num_batch": int,
"num_gpu": int,
"main_gpu": int,
"low_vram": bool,
"f16_kv": bool,
"vocab_only": bool,
"use_mmap": bool,
"use_mlock": bool,
"num_thread": int,
}
data = {"base_model_id": None, "params": {}}
......@@ -156,10 +211,18 @@ def parse_ollama_modelfile(model_text):
param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
if param_match:
value = param_match.group(1)
try:
if param_type == int:
value = int(value)
elif param_type == float:
value = float(value)
elif param_type == bool:
value = value.lower() == "true"
except Exception as e:
print(e)
continue
data["params"][param] = value
# Parse adapter
......@@ -171,8 +234,14 @@ def parse_ollama_modelfile(model_text):
system_desc_match = re.search(
r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
)
system_desc_match_single = re.search(
r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
)
if system_desc_match:
data["params"]["system"] = system_desc_match.group(1).strip()
elif system_desc_match_single:
data["params"]["system"] = system_desc_match_single.group(1).strip()
# Parse messages
messages = []
......
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
def get_model_id_from_custom_model_id(id: str):
model = Models.get_model_by_id(id)
if model:
return model.id
else:
return id
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment