Unverified Commit 8f6f7668 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3595 from open-webui/dev-migration

feat: db migration
parents 97a84918 4e751501
......@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel):
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
......
......@@ -5,6 +5,7 @@ from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
......@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)):
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(
request: Request, form_data: ModelForm, user=Depends(get_admin_user)
request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
):
if form_data.id in request.app.state.MODELS:
raise HTTPException(
......@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
):
model = Models.get_model_by_id(id)
if model:
......
......@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_admin_user)
command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
):
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
......
......@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
......@@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)):
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
):
if not form_data.id.isidentifier():
raise HTTPException(
......@@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
......
......@@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_settings_by_session_user(
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user)
):
user = Users.get_user_by_id(user.id)
......@@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
):
user = Users.get_user_by_id(user_id)
......
from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
......@@ -10,7 +9,6 @@ import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url
......@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if not isinstance(DB, SqliteDatabase):
from apps.webui.internal.db import engine
if engine.name != "sqlite":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse(
DB.database,
engine.url.database,
media_type="application/octet-stream",
filename="webui.db",
)
......
......@@ -778,6 +778,7 @@ class BannerModel(BaseModel):
dismissible: bool
timestamp: int
try:
banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
banners = [BannerModel(**banner) for banner in banners]
......
import base64
import uuid
import subprocess
from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
......@@ -27,6 +28,7 @@ from fastapi.responses import JSONResponse
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
......@@ -54,6 +56,7 @@ from apps.webui.main import (
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import Session, SessionLocal
from pydantic import BaseModel
......@@ -125,6 +128,8 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
BACKEND_DIR,
DATABASE_URL,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook
......@@ -167,8 +172,19 @@ https://github.com/open-webui/open-webui
)
def run_migrations():
env = os.environ.copy()
env["DATABASE_URL"] = DATABASE_URL
migration_task = subprocess.run(
["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env
)
if migration_task.returncode > 0:
raise ValueError("Error running migrations")
@asynccontextmanager
async def lifespan(app: FastAPI):
run_migrations()
yield
......@@ -899,6 +915,14 @@ app.add_middleware(
)
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
response = await call_next(request)
log.debug("Commit session after request")
Session.commit()
return response
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
......@@ -1744,7 +1768,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
models = await get_all_models()
r = None
......@@ -1782,7 +1808,9 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
models = await get_all_models()
......@@ -2171,6 +2199,12 @@ async def healthcheck():
return {"status": True}
@app.get("/health/db")
async def healthcheck_with_db():
Session.execute(text("SELECT 1;")).all()
return {"status": True}
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
......
Generic single-database configuration.
Create new migrations with
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from apps.webui.models.auths import Auth
from apps.webui.models.chats import Chat
from apps.webui.models.documents import Document
from apps.webui.models.memories import Memory
from apps.webui.models.models import Model
from apps.webui.models.prompts import Prompt
from apps.webui.models.tags import Tag, ChatIdTag
from apps.webui.models.tools import Tool
from apps.webui.models.users import User
from apps.webui.models.files import File
from apps.webui.models.functions import Function
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Auth.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
database_url = os.getenv("DATABASE_URL", None)
if database_url:
config.set_main_option("sqlalchemy.url", database_url)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}
from alembic import op
from sqlalchemy import Inspector
def get_existing_tables():
con = op.get_bind()
inspector = Inspector.from_engine(con)
tables = set(inspector.get_table_names())
return tables
"""init
Revision ID: 7e5b5dc7342b
Revises:
Create Date: 2024-06-24 13:15:33.808998
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
from migrations.util import get_existing_tables
# revision identifiers, used by Alembic.
revision: str = "7e5b5dc7342b"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
existing_tables = set(get_existing_tables())
# ### commands auto generated by Alembic - please adjust! ###
if "auth" not in existing_tables:
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.Text(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "chat" not in existing_tables:
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("chat", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.Text(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
)
if "chatidtag" not in existing_tables:
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "document" not in existing_tables:
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
)
if "file" not in existing_tables:
op.create_table(
"file",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "function" not in existing_tables:
op.create_table(
"function",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("type", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("is_global", sa.Boolean(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "memory" not in existing_tables:
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "model" not in existing_tables:
op.create_table(
"model",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("base_model_id", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "prompt" not in existing_tables:
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
)
if "tag" not in existing_tables:
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "tool" not in existing_tables:
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "user" not in existing_tables:
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("oauth_sub", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
sa.UniqueConstraint("oauth_sub"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user")
op.drop_table("tool")
op.drop_table("tag")
op.drop_table("prompt")
op.drop_table("model")
op.drop_table("memory")
op.drop_table("function")
op.drop_table("file")
op.drop_table("document")
op.drop_table("chatidtag")
op.drop_table("chat")
op.drop_table("auth")
# ### end Alembic commands ###
......@@ -12,6 +12,8 @@ passlib[bcrypt]==1.7.4
requests==2.32.3
aiohttp==3.9.5
sqlalchemy==2.0.30
alembic==1.13.1
peewee==3.17.5
peewee-migrate==1.12.2
psycopg2-binary==2.9.9
......@@ -67,4 +69,9 @@ pytube==15.0.0
extract_msg
pydub
duckduckgo-search~=6.1.7
\ No newline at end of file
duckduckgo-search~=6.1.7
## Tests
docker~=7.1.0
pytest~=8.2.1
pytest-docker~=3.1.1
import pytest
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestAuths(AbstractPostgresTest):
BASE_PATH = "/api/v1/auths"
def setup_class(cls):
super().setup_class()
from apps.webui.models.users import Users
from apps.webui.models.auths import Auths
cls.users = Users
cls.auths = Auths
def test_get_session_user(self):
with mock_webui_user():
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert response.json() == {
"id": "1",
"name": "John Doe",
"email": "john.doe@openwebui.com",
"role": "user",
"profile_image_url": "/user.png",
}
def test_update_profile(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(
self.create_url("/update/profile"),
json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
)
assert response.status_code == 200
db_user = self.users.get_user_by_id(user.id)
assert db_user.name == "John Doe 2"
assert db_user.profile_image_url == "/user2.png"
def test_update_password(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(
self.create_url("/update/password"),
json={"password": "old_password", "new_password": "new_password"},
)
assert response.status_code == 200
old_auth = self.auths.authenticate_user(
"john.doe@openwebui.com", "old_password"
)
assert old_auth is None
new_auth = self.auths.authenticate_user(
"john.doe@openwebui.com", "new_password"
)
assert new_auth is not None
def test_signin(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password=get_password_hash("password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
response = self.fast_api_client.post(
self.create_url("/signin"),
json={"email": "john.doe@openwebui.com", "password": "password"},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == user.id
assert data["name"] == "John Doe"
assert data["email"] == "john.doe@openwebui.com"
assert data["role"] == "user"
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_signup(self):
response = self.fast_api_client.post(
self.create_url("/signup"),
json={
"name": "John Doe",
"email": "john.doe@openwebui.com",
"password": "password",
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] is not None and len(data["id"]) > 0
assert data["name"] == "John Doe"
assert data["email"] == "john.doe@openwebui.com"
assert data["role"] in ["admin", "user", "pending"]
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_add_user(self):
with mock_webui_user():
response = self.fast_api_client.post(
self.create_url("/add"),
json={
"name": "John Doe 2",
"email": "john.doe2@openwebui.com",
"password": "password2",
"role": "admin",
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] is not None and len(data["id"]) > 0
assert data["name"] == "John Doe 2"
assert data["email"] == "john.doe2@openwebui.com"
assert data["role"] == "admin"
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_get_admin_details(self):
self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
with mock_webui_user():
response = self.fast_api_client.get(self.create_url("/admin/details"))
assert response.status_code == 200
assert response.json() == {
"name": "John Doe",
"email": "john.doe@openwebui.com",
}
def test_create_api_key_(self):
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(self.create_url("/api_key"))
assert response.status_code == 200
data = response.json()
assert data["api_key"] is not None
assert len(data["api_key"]) > 0
def test_delete_api_key(self):
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.delete(self.create_url("/api_key"))
assert response.status_code == 200
assert response.json() == True
db_user = self.users.get_user_by_id(user.id)
assert db_user.api_key is None
def test_get_api_key(self):
user = self.auths.insert_new_auth(
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.get(self.create_url("/api_key"))
assert response.status_code == 200
assert response.json() == {"api_key": "abc"}
import uuid
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestChats(AbstractPostgresTest):
BASE_PATH = "/api/v1/chats"
def setup_class(cls):
super().setup_class()
def setup_method(self):
super().setup_method()
from apps.webui.models.chats import ChatForm
from apps.webui.models.chats import Chats
self.chats = Chats
self.chats.insert_new_chat(
"2",
ChatForm(
**{
"chat": {
"name": "chat1",
"description": "chat1 description",
"tags": ["tag1", "tag2"],
"history": {"currentId": "1", "messages": []},
}
}
),
)
def test_get_session_user_chat_list(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
first_chat = response.json()[0]
assert first_chat["id"] is not None
assert first_chat["title"] == "New Chat"
assert first_chat["created_at"] is not None
assert first_chat["updated_at"] is not None
def test_delete_all_user_chats(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url("/"))
assert response.status_code == 200
assert len(self.chats.get_chats()) == 0
def test_get_user_chat_list_by_user_id(self):
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url("/list/user/2"))
assert response.status_code == 200
first_chat = response.json()[0]
assert first_chat["id"] is not None
assert first_chat["title"] == "New Chat"
assert first_chat["created_at"] is not None
assert first_chat["updated_at"] is not None
def test_create_new_chat(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/new"),
json={
"chat": {
"name": "chat2",
"description": "chat2 description",
"tags": ["tag1", "tag2"],
}
},
)
assert response.status_code == 200
data = response.json()
assert data["archived"] is False
assert data["chat"] == {
"name": "chat2",
"description": "chat2 description",
"tags": ["tag1", "tag2"],
}
assert data["user_id"] == "2"
assert data["id"] is not None
assert data["share_id"] is None
assert data["title"] == "New Chat"
assert data["updated_at"] is not None
assert data["created_at"] is not None
assert len(self.chats.get_chats()) == 2
def test_get_user_chats(self):
self.test_get_session_user_chat_list()
def test_get_user_archived_chats(self):
self.chats.archive_all_chats_by_user_id("2")
from apps.webui.internal.db import Session
Session.commit()
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/all/archived"))
assert response.status_code == 200
first_chat = response.json()[0]
assert first_chat["id"] is not None
assert first_chat["title"] == "New Chat"
assert first_chat["created_at"] is not None
assert first_chat["updated_at"] is not None
def test_get_all_user_chats_in_db(self):
with mock_webui_user(id="4"):
response = self.fast_api_client.get(self.create_url("/all/db"))
assert response.status_code == 200
assert len(response.json()) == 1
def test_get_archived_session_user_chat_list(self):
self.test_get_user_archived_chats()
def test_archive_all_chats(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url("/archive/all"))
assert response.status_code == 200
assert len(self.chats.get_archived_chats_by_user_id("2")) == 1
def test_get_shared_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
self.chats.update_chat_share_id_by_id(chat_id, chat_id)
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
assert response.status_code == 200
data = response.json()
assert data["id"] == chat_id
assert data["chat"] == {
"name": "chat1",
"description": "chat1 description",
"tags": ["tag1", "tag2"],
"history": {"currentId": "1", "messages": []},
}
assert data["id"] == chat_id
assert data["share_id"] == chat_id
assert data["title"] == "New Chat"
def test_get_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
assert response.status_code == 200
data = response.json()
assert data["id"] == chat_id
assert data["chat"] == {
"name": "chat1",
"description": "chat1 description",
"tags": ["tag1", "tag2"],
"history": {"currentId": "1", "messages": []},
}
assert data["share_id"] is None
assert data["title"] == "New Chat"
assert data["user_id"] == "2"
def test_update_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url(f"/{chat_id}"),
json={
"chat": {
"name": "chat2",
"description": "chat2 description",
"tags": ["tag2", "tag4"],
"title": "Just another title",
}
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == chat_id
assert data["chat"] == {
"name": "chat2",
"title": "Just another title",
"description": "chat2 description",
"tags": ["tag2", "tag4"],
"history": {"currentId": "1", "messages": []},
}
assert data["share_id"] is None
assert data["title"] == "Just another title"
assert data["user_id"] == "2"
def test_delete_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
assert response.status_code == 200
assert response.json() is True
def test_clone_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
assert response.status_code == 200
data = response.json()
assert data["id"] != chat_id
assert data["chat"] == {
"branchPointMessageId": "1",
"description": "chat1 description",
"history": {"currentId": "1", "messages": []},
"name": "chat1",
"originalChatId": chat_id,
"tags": ["tag1", "tag2"],
"title": "Clone of New Chat",
}
assert data["share_id"] is None
assert data["title"] == "Clone of New Chat"
assert data["user_id"] == "2"
def test_archive_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
assert response.status_code == 200
chat = self.chats.get_chat_by_id(chat_id)
assert chat.archived is True
def test_share_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
assert response.status_code == 200
chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is not None
def test_delete_shared_chat_by_id(self):
chat_id = self.chats.get_chats()[0].id
share_id = str(uuid.uuid4())
self.chats.update_chat_share_id_by_id(chat_id, share_id)
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
assert response.status_code
chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is None
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestDocuments(AbstractPostgresTest):
BASE_PATH = "/api/v1/documents"
def setup_class(cls):
super().setup_class()
from apps.webui.models.documents import Documents
cls.documents = Documents
def test_documents(self):
# Empty database
assert len(self.documents.get_docs()) == 0
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
# Create a new document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"name": "doc_name",
"title": "doc title",
"collection_name": "custom collection",
"filename": "doc_name.pdf",
"content": "",
},
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name"
assert len(self.documents.get_docs()) == 1
# Get the document
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/doc?name=doc_name"))
assert response.status_code == 200
data = response.json()
assert data["collection_name"] == "custom collection"
assert data["name"] == "doc_name"
assert data["title"] == "doc title"
assert data["filename"] == "doc_name.pdf"
assert data["content"] == {}
# Create another document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"name": "doc_name 2",
"title": "doc title 2",
"collection_name": "custom collection 2",
"filename": "doc_name2.pdf",
"content": "",
},
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name 2"
assert len(self.documents.get_docs()) == 2
# Get all documents
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 2
# Update the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/doc/update?name=doc_name"),
json={"name": "doc_name rework", "title": "updated title"},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "doc_name rework"
assert data["title"] == "updated title"
# Tag the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/doc/tags"),
json={
"name": "doc_name rework",
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}],
},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "doc_name rework"
assert data["content"] == {
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
}
assert len(self.documents.get_docs()) == 2
# Delete the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/doc/delete?name=doc_name rework")
)
assert response.status_code == 200
assert len(self.documents.get_docs()) == 1
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestModels(AbstractPostgresTest):
BASE_PATH = "/api/v1/models"
def setup_class(cls):
super().setup_class()
from apps.webui.models.models import Model
cls.models = Model
def test_models(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/add"),
json={
"id": "my-model",
"base_model_id": "base-model-id",
"name": "Hello World",
"meta": {
"profile_image_url": "/favicon.png",
"description": "description",
"capabilities": None,
"model_config": {},
},
"params": {},
},
)
assert response.status_code == 200
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 1
with mock_webui_user(id="2"):
response = self.fast_api_client.get(
self.create_url(query_params={"id": "my-model"})
)
assert response.status_code == 200
data = response.json()[0]
assert data["id"] == "my-model"
assert data["name"] == "Hello World"
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/delete?id=my-model")
)
assert response.status_code == 200
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestPrompts(AbstractPostgresTest):
BASE_PATH = "/api/v1/prompts"
def test_prompts(self):
# Get all prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
# Create a two new prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"command": "/my-command",
"title": "Hello World",
"content": "description",
},
)
assert response.status_code == 200
with mock_webui_user(id="3"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"command": "/my-command2",
"title": "Hello World 2",
"content": "description 2",
},
)
assert response.status_code == 200
# Get all prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 2
# Get prompt by command
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/command/my-command"))
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command"
assert data["title"] == "Hello World"
assert data["content"] == "description"
assert data["user_id"] == "2"
# Update prompt
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/command/my-command2/update"),
json={
"command": "irrelevant for request",
"title": "Hello World Updated",
"content": "description Updated",
},
)
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command2"
assert data["title"] == "Hello World Updated"
assert data["content"] == "description Updated"
assert data["user_id"] == "3"
# Get prompt by command
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/command/my-command2"))
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command2"
assert data["title"] == "Hello World Updated"
assert data["content"] == "description Updated"
assert data["user_id"] == "3"
# Delete prompt
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/command/my-command/delete")
)
assert response.status_code == 200
# Get all prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment