".github/vscode:/vscode.git/clone" did not exist on "46ab56a468cbc43c1ab48190e27fc42195f4c60b"
Commit f9e3c47d authored by Michael Poluektov's avatar Michael Poluektov
Browse files

rebase

parents 49b4211c 24ef5af2
...@@ -50,10 +50,7 @@ router = APIRouter() ...@@ -50,10 +50,7 @@ router = APIRouter()
@router.post("/") @router.post("/")
def upload_file( def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
try: try:
unsanitized_filename = file.filename unsanitized_filename = file.filename
......
...@@ -233,7 +233,10 @@ async def delete_function_by_id( ...@@ -233,7 +233,10 @@ async def delete_function_by_id(
# delete the function file # delete the function file
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
os.remove(function_path) try:
os.remove(function_path)
except:
pass
return result return result
......
...@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel): ...@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel):
@router.post("/add", response_model=Optional[MemoryModel]) @router.post("/add", response_model=Optional[MemoryModel])
async def add_memory( async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user) request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
): ):
memory = Memories.insert_new_memory(user.id, form_data.content) memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
......
...@@ -5,6 +5,7 @@ from typing import List, Union, Optional ...@@ -5,6 +5,7 @@ from typing import List, Union, Optional
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user from utils.utils import get_verified_user, get_admin_user
...@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)): ...@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)):
@router.post("/add", response_model=Optional[ModelModel]) @router.post("/add", response_model=Optional[ModelModel])
async def add_new_model( async def add_new_model(
request: Request, form_data: ModelForm, user=Depends(get_admin_user) request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
): ):
if form_data.id in request.app.state.MODELS: if form_data.id in request.app.state.MODELS:
raise HTTPException( raise HTTPException(
...@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): ...@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/update", response_model=Optional[ModelModel]) @router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id( async def update_model_by_id(
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user) request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
): ):
model = Models.get_model_by_id(id) model = Models.get_model_by_id(id)
if model: if model:
......
...@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): ...@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel]) @router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command( async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_admin_user) command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
): ):
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt: if prompt:
......
...@@ -6,7 +6,6 @@ from fastapi import APIRouter ...@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.models.users import Users from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.utils import load_toolkit_module_by_id
...@@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)): ...@@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)):
@router.post("/create", response_model=Optional[ToolResponse]) @router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit( async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user) request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
): ):
if not form_data.id.isidentifier(): if not form_data.id.isidentifier():
raise HTTPException( raise HTTPException(
...@@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ...@@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[ToolModel]) @router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id( async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user) request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
): ):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
......
...@@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)): ...@@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/info/update", response_model=Optional[dict]) @router.post("/user/info/update", response_model=Optional[dict])
async def update_user_settings_by_session_user( async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user) form_data: dict, user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(user.id) user = Users.get_user_by_id(user.id)
...@@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): ...@@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.post("/{user_id}/update", response_model=Optional[UserModel]) @router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id( async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user) user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
): ):
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
......
from fastapi import APIRouter, UploadFile, File, Response from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
...@@ -10,7 +9,6 @@ import markdown ...@@ -10,7 +9,6 @@ import markdown
import black import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url from utils.misc import calculate_sha256, get_gravatar_url
...@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)): ...@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
if not isinstance(DB, SqliteDatabase): from apps.webui.internal.db import engine
if engine.name != "sqlite":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE, detail=ERROR_MESSAGES.DB_NOT_SQLITE,
) )
return FileResponse( return FileResponse(
DB.database, engine.url.database,
media_type="application/octet-stream", media_type="application/octet-stream",
filename="webui.db", filename="webui.db",
) )
......
...@@ -393,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig( ...@@ -393,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"), os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
) )
OAUTH_USERNAME_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.username_claim",
os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)
OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
def load_oauth_providers(): def load_oauth_providers():
OAUTH_PROVIDERS.clear() OAUTH_PROVIDERS.clear()
...@@ -438,16 +450,27 @@ load_oauth_providers() ...@@ -438,16 +450,27 @@ load_oauth_providers()
STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve() STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png" frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
if frontend_favicon.exists(): if frontend_favicon.exists():
try: try:
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png") shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e: except Exception as e:
logging.error(f"An error occurred: {e}") logging.error(f"An error occurred: {e}")
else: else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}") logging.warning(f"Frontend favicon not found at {frontend_favicon}")
frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
if frontend_splash.exists():
try:
shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend splash not found at {frontend_splash}")
#################################### ####################################
# CUSTOM_NAME # CUSTOM_NAME
#################################### ####################################
...@@ -472,6 +495,19 @@ if CUSTOM_NAME: ...@@ -472,6 +495,19 @@ if CUSTOM_NAME:
r.raw.decode_content = True r.raw.decode_content = True
shutil.copyfileobj(r.raw, f) shutil.copyfileobj(r.raw, f)
if "splash" in data:
url = (
f"https://api.openwebui.com{data['splash']}"
if data["splash"][0] == "/"
else data["splash"]
)
r = requests.get(url, stream=True)
if r.status_code == 200:
with open(f"{STATIC_DIR}/splash.png", "wb") as f:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
WEBUI_NAME = data["name"] WEBUI_NAME = data["name"]
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
...@@ -766,6 +802,7 @@ class BannerModel(BaseModel): ...@@ -766,6 +802,7 @@ class BannerModel(BaseModel):
dismissible: bool dismissible: bool
timestamp: int timestamp: int
try: try:
banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]")) banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
banners = [BannerModel(**banner) for banner in banners] banners = [BannerModel(**banner) for banner in banners]
...@@ -1318,3 +1355,7 @@ AUDIO_TTS_VOICE = PersistentConfig( ...@@ -1318,3 +1355,7 @@ AUDIO_TTS_VOICE = PersistentConfig(
#################################### ####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
...@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum): ...@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
OLLAMA_API_DISABLED = ( OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature." "The Ollama API is disabled. Please enable it to use this feature."
) )
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'default'}"
TITLE_GENERATION = "Title Generation"
EMOJI_GENERATION = "Emoji Generation"
QUERY_GENERATION = "Query Generation"
FUNCTION_CALLING = "Function Calling"
import base64 import base64
import uuid import uuid
import subprocess
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import OAuth
...@@ -27,6 +28,7 @@ from fastapi.responses import JSONResponse ...@@ -27,6 +28,7 @@ from fastapi.responses import JSONResponse
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
...@@ -54,6 +56,7 @@ from apps.webui.main import ( ...@@ -54,6 +56,7 @@ from apps.webui.main import (
get_pipe_models, get_pipe_models,
generate_function_chat_completion, generate_function_chat_completion,
) )
from apps.webui.internal.db import Session, SessionLocal
from pydantic import BaseModel from pydantic import BaseModel
...@@ -125,8 +128,10 @@ from config import ( ...@@ -125,8 +128,10 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE, WEBUI_SESSION_COOKIE_SECURE,
AppConfig, AppConfig,
BACKEND_DIR,
DATABASE_URL,
) )
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook from utils.webhook import post_webhook
if SAFE_MODE: if SAFE_MODE:
...@@ -167,8 +172,20 @@ https://github.com/open-webui/open-webui ...@@ -167,8 +172,20 @@ https://github.com/open-webui/open-webui
) )
def run_migrations():
try:
from alembic.config import Config
from alembic import command
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
except Exception as e:
print(f"Error: {e}")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
run_migrations()
yield yield
...@@ -285,6 +302,7 @@ async def get_function_call_response( ...@@ -285,6 +302,7 @@ async def get_function_call_response(
user, user,
model, model,
__event_emitter__=None, __event_emitter__=None,
__event_call__=None,
): ):
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
...@@ -311,7 +329,7 @@ async def get_function_call_response( ...@@ -311,7 +329,7 @@ async def get_function_call_response(
{"role": "user", "content": f"Query: {prompt}"}, {"role": "user", "content": f"Query: {prompt}"},
], ],
"stream": False, "stream": False,
"function": True, "task": TASKS.FUNCTION_CALLING,
} }
try: try:
...@@ -324,7 +342,6 @@ async def get_function_call_response( ...@@ -324,7 +342,6 @@ async def get_function_call_response(
response = None response = None
try: try:
response = await generate_chat_completions(form_data=payload, user=user) response = await generate_chat_completions(form_data=payload, user=user)
content = None content = None
if hasattr(response, "body_iterator"): if hasattr(response, "body_iterator"):
...@@ -429,6 +446,13 @@ async def get_function_call_response( ...@@ -429,6 +446,13 @@ async def get_function_call_response(
"__event_emitter__": __event_emitter__, "__event_emitter__": __event_emitter__,
} }
if "__event_call__" in sig.parameters:
# Call the function with the '__event_call__' parameter included
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(function): if inspect.iscoroutinefunction(function):
function_result = await function(**params) function_result = await function(**params)
else: else:
...@@ -452,7 +476,9 @@ async def get_function_call_response( ...@@ -452,7 +476,9 @@ async def get_function_call_response(
return None, None, False return None, None, False
async def chat_completion_functions_handler(body, model, user, __event_emitter__): async def chat_completion_functions_handler(
body, model, user, __event_emitter__, __event_call__
):
skip_files = None skip_files = None
filter_ids = get_filter_function_ids(model) filter_ids = get_filter_function_ids(model)
...@@ -518,12 +544,19 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__ ...@@ -518,12 +544,19 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
**params, **params,
"__model__": model, "__model__": model,
} }
if "__event_emitter__" in sig.parameters: if "__event_emitter__" in sig.parameters:
params = { params = {
**params, **params,
"__event_emitter__": __event_emitter__, "__event_emitter__": __event_emitter__,
} }
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(inlet): if inspect.iscoroutinefunction(inlet):
body = await inlet(**params) body = await inlet(**params)
else: else:
...@@ -540,7 +573,9 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__ ...@@ -540,7 +573,9 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
return body, {} return body, {}
async def chat_completion_tools_handler(body, model, user, __event_emitter__): async def chat_completion_tools_handler(
body, model, user, __event_emitter__, __event_call__
):
skip_files = None skip_files = None
contexts = [] contexts = []
...@@ -563,6 +598,7 @@ async def chat_completion_tools_handler(body, model, user, __event_emitter__): ...@@ -563,6 +598,7 @@ async def chat_completion_tools_handler(body, model, user, __event_emitter__):
user=user, user=user,
model=model, model=model,
__event_emitter__=__event_emitter__, __event_emitter__=__event_emitter__,
__event_call__=__event_call__,
) )
print(file_handler) print(file_handler)
...@@ -660,6 +696,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -660,6 +696,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
to=session_id, to=session_id,
) )
async def __event_call__(data):
response = await sio.call(
"chat-events",
{"chat_id": chat_id, "message_id": message_id, "data": data},
to=session_id,
)
return response
# Initialize data_items to store additional data to be sent to the client # Initialize data_items to store additional data to be sent to the client
data_items = [] data_items = []
...@@ -669,7 +713,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -669,7 +713,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try: try:
body, flags = await chat_completion_functions_handler( body, flags = await chat_completion_functions_handler(
body, model, user, __event_emitter__ body, model, user, __event_emitter__, __event_call__
) )
except Exception as e: except Exception as e:
return JSONResponse( return JSONResponse(
...@@ -679,7 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -679,7 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
try: try:
body, flags = await chat_completion_tools_handler( body, flags = await chat_completion_tools_handler(
body, model, user, __event_emitter__ body, model, user, __event_emitter__, __event_call__
) )
contexts.extend(flags.get("contexts", [])) contexts.extend(flags.get("contexts", []))
...@@ -834,9 +878,8 @@ def filter_pipeline(payload, user): ...@@ -834,9 +878,8 @@ def filter_pipeline(payload, user):
pass pass
if "pipeline" not in app.state.MODELS[model_id]: if "pipeline" not in app.state.MODELS[model_id]:
for key in ["title", "task", "function"]: if "task" in payload:
if key in payload: del payload["task"]
del payload[key]
return payload return payload
...@@ -901,6 +944,14 @@ app.add_middleware( ...@@ -901,6 +944,14 @@ app.add_middleware(
) )
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
response = await call_next(request)
log.debug("Commit session after request")
Session.commit()
return response
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0: if len(app.state.MODELS) == 0:
...@@ -977,12 +1028,16 @@ async def get_all_models(): ...@@ -977,12 +1028,16 @@ async def get_all_models():
model["info"] = custom_model.model_dump() model["info"] = custom_model.model_dump()
else: else:
owned_by = "openai" owned_by = "openai"
pipe = None
for model in models: for model in models:
if ( if (
custom_model.base_model_id == model["id"] custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0] or custom_model.base_model_id == model["id"].split(":")[0]
): ):
owned_by = model["owned_by"] owned_by = model["owned_by"]
if "pipe" in model:
pipe = model["pipe"]
break break
models.append( models.append(
...@@ -994,11 +1049,11 @@ async def get_all_models(): ...@@ -994,11 +1049,11 @@ async def get_all_models():
"owned_by": owned_by, "owned_by": owned_by,
"info": custom_model.model_dump(), "info": custom_model.model_dump(),
"preset": True, "preset": True,
**({"pipe": pipe} if pipe is not None else {}),
} }
) )
app.state.MODELS = {model["id"]: model for model in models} app.state.MODELS = {model["id"]: model for model in models}
webui_app.state.MODELS = app.state.MODELS webui_app.state.MODELS = app.state.MODELS
return models return models
...@@ -1133,6 +1188,14 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1133,6 +1188,14 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
to=data["session_id"], to=data["session_id"],
) )
async def __event_call__(data):
response = await sio.call(
"chat-events",
{"chat_id": data["chat_id"], "message_id": data["id"], "data": data},
to=data["session_id"],
)
return response
def get_priority(function_id): def get_priority(function_id):
function = Functions.get_function_by_id(function_id) function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"): if function is not None and hasattr(function, "valves"):
...@@ -1220,6 +1283,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1220,6 +1283,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
"__event_emitter__": __event_emitter__, "__event_emitter__": __event_emitter__,
} }
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(outlet): if inspect.iscoroutinefunction(outlet):
data = await outlet(**params) data = await outlet(**params)
else: else:
...@@ -1337,7 +1406,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -1337,7 +1406,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"stream": False, "stream": False,
"max_tokens": 50, "max_tokens": 50,
"chat_id": form_data.get("chat_id", None), "chat_id": form_data.get("chat_id", None),
"title": True, "task": TASKS.TITLE_GENERATION,
} }
log.debug(payload) log.debug(payload)
...@@ -1400,7 +1469,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -1400,7 +1469,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages": [{"role": "user", "content": content}], "messages": [{"role": "user", "content": content}],
"stream": False, "stream": False,
"max_tokens": 30, "max_tokens": 30,
"task": True, "task": TASKS.QUERY_GENERATION,
} }
print(payload) print(payload)
...@@ -1467,7 +1536,7 @@ Message: """{{prompt}}""" ...@@ -1467,7 +1536,7 @@ Message: """{{prompt}}"""
"stream": False, "stream": False,
"max_tokens": 4, "max_tokens": 4,
"chat_id": form_data.get("chat_id", None), "chat_id": form_data.get("chat_id", None),
"task": True, "task": TASKS.EMOJI_GENERATION,
} }
log.debug(payload) log.debug(payload)
...@@ -1742,7 +1811,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use ...@@ -1742,7 +1811,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves") @app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves( async def get_pipeline_valves(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user) urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
): ):
models = await get_all_models() models = await get_all_models()
r = None r = None
...@@ -1780,7 +1851,9 @@ async def get_pipeline_valves( ...@@ -1780,7 +1851,9 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec") @app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec( async def get_pipeline_valves_spec(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user) urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
): ):
models = await get_all_models() models = await get_all_models()
...@@ -2066,7 +2139,8 @@ async def oauth_callback(provider: str, request: Request, response: Response): ...@@ -2066,7 +2139,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
if existing_user: if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_url = user_data.get("picture", "") picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url: if picture_url:
# Download the profile image into a base64 string # Download the profile image into a base64 string
try: try:
...@@ -2086,6 +2160,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): ...@@ -2086,6 +2160,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
picture_url = "" picture_url = ""
if not picture_url: if not picture_url:
picture_url = "/user.png" picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = ( role = (
"admin" "admin"
if Users.get_num_users() == 0 if Users.get_num_users() == 0
...@@ -2096,7 +2171,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): ...@@ -2096,7 +2171,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
password=get_password_hash( password=get_password_hash(
str(uuid.uuid4()) str(uuid.uuid4())
), # Random password, not used ), # Random password, not used
name=user_data.get("name", "User"), name=user_data.get(username_claim, "User"),
profile_image_url=picture_url, profile_image_url=picture_url,
role=role, role=role,
oauth_sub=provider_sub, oauth_sub=provider_sub,
...@@ -2154,7 +2229,7 @@ async def get_opensearch_xml(): ...@@ -2154,7 +2229,7 @@ async def get_opensearch_xml():
<ShortName>{WEBUI_NAME}</ShortName> <ShortName>{WEBUI_NAME}</ShortName>
<Description>Search {WEBUI_NAME}</Description> <Description>Search {WEBUI_NAME}</Description>
<InputEncoding>UTF-8</InputEncoding> <InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image> <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/> <Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
<moz:SearchForm>{WEBUI_URL}</moz:SearchForm> <moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
</OpenSearchDescription> </OpenSearchDescription>
...@@ -2167,6 +2242,12 @@ async def healthcheck(): ...@@ -2167,6 +2242,12 @@ async def healthcheck():
return {"status": True} return {"status": True}
@app.get("/health/db")
async def healthcheck_with_db():
Session.execute(text("SELECT 1;")).all()
return {"status": True}
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
......
Generic single-database configuration.
Create new migrations with
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from apps.webui.models.auths import Auth
from apps.webui.models.chats import Chat
from apps.webui.models.documents import Document
from apps.webui.models.memories import Memory
from apps.webui.models.models import Model
from apps.webui.models.prompts import Prompt
from apps.webui.models.tags import Tag, ChatIdTag
from apps.webui.models.tools import Tool
from apps.webui.models.users import User
from apps.webui.models.files import File
from apps.webui.models.functions import Function
from config import DATABASE_URL
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Auth.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
DB_URL = DATABASE_URL
if DB_URL:
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}
from alembic import op
from sqlalchemy import Inspector
def get_existing_tables():
con = op.get_bind()
inspector = Inspector.from_engine(con)
tables = set(inspector.get_table_names())
return tables
"""init
Revision ID: 7e5b5dc7342b
Revises:
Create Date: 2024-06-24 13:15:33.808998
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
from migrations.util import get_existing_tables
# revision identifiers, used by Alembic.
revision: str = "7e5b5dc7342b"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
existing_tables = set(get_existing_tables())
# ### commands auto generated by Alembic - please adjust! ###
if "auth" not in existing_tables:
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.Text(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "chat" not in existing_tables:
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("chat", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.Text(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
)
if "chatidtag" not in existing_tables:
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "document" not in existing_tables:
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
)
if "file" not in existing_tables:
op.create_table(
"file",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "function" not in existing_tables:
op.create_table(
"function",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("type", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("is_global", sa.Boolean(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "memory" not in existing_tables:
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "model" not in existing_tables:
op.create_table(
"model",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("base_model_id", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "prompt" not in existing_tables:
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
)
if "tag" not in existing_tables:
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "tool" not in existing_tables:
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "user" not in existing_tables:
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("oauth_sub", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
sa.UniqueConstraint("oauth_sub"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user")
op.drop_table("tool")
op.drop_table("tag")
op.drop_table("prompt")
op.drop_table("model")
op.drop_table("memory")
op.drop_table("function")
op.drop_table("file")
op.drop_table("document")
op.drop_table("chatidtag")
op.drop_table("chat")
op.drop_table("auth")
# ### end Alembic commands ###
...@@ -12,7 +12,9 @@ passlib[bcrypt]==1.7.4 ...@@ -12,7 +12,9 @@ passlib[bcrypt]==1.7.4
requests==2.32.3 requests==2.32.3
aiohttp==3.9.5 aiohttp==3.9.5
peewee==3.17.5 sqlalchemy==2.0.30
alembic==1.13.2
peewee==3.17.6
peewee-migrate==1.12.2 peewee-migrate==1.12.2
psycopg2-binary==2.9.9 psycopg2-binary==2.9.9
PyMySQL==1.1.1 PyMySQL==1.1.1
...@@ -49,7 +51,7 @@ pyxlsb==1.0.10 ...@@ -49,7 +51,7 @@ pyxlsb==1.0.10
xlrd==2.0.1 xlrd==2.0.1
validators==0.28.1 validators==0.28.1
opencv-python-headless==4.9.0.80 opencv-python-headless==4.10.0.84
rapidocr-onnxruntime==1.3.22 rapidocr-onnxruntime==1.3.22
fpdf2==2.7.9 fpdf2==2.7.9
...@@ -61,10 +63,15 @@ PyJWT[crypto]==2.8.0 ...@@ -61,10 +63,15 @@ PyJWT[crypto]==2.8.0
authlib==1.3.1 authlib==1.3.1
black==24.4.2 black==24.4.2
langfuse==2.36.2 langfuse==2.38.0
youtube-transcript-api==0.6.2 youtube-transcript-api==0.6.2
pytube==15.0.0 pytube==15.0.0
extract_msg extract_msg
pydub pydub
duckduckgo-search~=6.1.7 duckduckgo-search~=6.1.7
\ No newline at end of file
## Tests
docker~=7.1.0
pytest~=8.2.2
pytest-docker~=3.1.1
import pytest
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestAuths(AbstractPostgresTest):
BASE_PATH = "/api/v1/auths"
def setup_class(cls):
super().setup_class()
from apps.webui.models.users import Users
from apps.webui.models.auths import Auths
cls.users = Users
cls.auths = Auths
def test_get_session_user(self):
with mock_webui_user():
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert response.json() == {
"id": "1",
"name": "John Doe",
"email": "john.doe@openwebui.com",
"role": "user",
"profile_image_url": "/user.png",
}
def test_update_profile(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(
self.create_url("/update/profile"),
json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
)
assert response.status_code == 200
db_user = self.users.get_user_by_id(user.id)
assert db_user.name == "John Doe 2"
assert db_user.profile_image_url == "/user2.png"
def test_update_password(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(
self.create_url("/update/password"),
json={"password": "old_password", "new_password": "new_password"},
)
assert response.status_code == 200
old_auth = self.auths.authenticate_user(
"john.doe@openwebui.com", "old_password"
)
assert old_auth is None
new_auth = self.auths.authenticate_user(
"john.doe@openwebui.com", "new_password"
)
assert new_auth is not None
def test_signin(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password=get_password_hash("password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
response = self.fast_api_client.post(
self.create_url("/signin"),
json={"email": "john.doe@openwebui.com", "password": "password"},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == user.id
assert data["name"] == "John Doe"
assert data["email"] == "john.doe@openwebui.com"
assert data["role"] == "user"
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_signup(self):
response = self.fast_api_client.post(
self.create_url("/signup"),
json={
"name": "John Doe",
"email": "john.doe@openwebui.com",
"password": "password",
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] is not None and len(data["id"]) > 0
assert data["name"] == "John Doe"
assert data["email"] == "john.doe@openwebui.com"
assert data["role"] in ["admin", "user", "pending"]
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_add_user(self):
with mock_webui_user():
response = self.fast_api_client.post(
self.create_url("/add"),
json={
"name": "John Doe 2",
"email": "john.doe2@openwebui.com",
"password": "password2",
"role": "admin",
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] is not None and len(data["id"]) > 0
assert data["name"] == "John Doe 2"
assert data["email"] == "john.doe2@openwebui.com"
assert data["role"] == "admin"
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_get_admin_details(self):
self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
with mock_webui_user():
response = self.fast_api_client.get(self.create_url("/admin/details"))
assert response.status_code == 200
assert response.json() == {
"name": "John Doe",
"email": "john.doe@openwebui.com",
}
def test_create_api_key_(self):
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(self.create_url("/api_key"))
assert response.status_code == 200
data = response.json()
assert data["api_key"] is not None
assert len(data["api_key"]) > 0
def test_delete_api_key(self):
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.delete(self.create_url("/api_key"))
assert response.status_code == 200
assert response.json() == True
db_user = self.users.get_user_by_id(user.id)
assert db_user.api_key is None
def test_get_api_key(self):
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.get(self.create_url("/api_key"))
assert response.status_code == 200
assert response.json() == {"api_key": "abc"}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment