Commit df09d083 authored by Jonathan Rohde's avatar Jonathan Rohde
Browse files

feat(sqlalchemy): Replace peewee with sqlalchemy

parent 8dac2a21
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestModels(AbstractPostgresTest):
BASE_PATH = "/api/v1/models"
def setup_class(cls):
super().setup_class()
from apps.webui.models.models import Model
cls.models = Model
def test_models(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/add"),
json={
"id": "my-model",
"base_model_id": "base-model-id",
"name": "Hello World",
"meta": {
"profile_image_url": "/favicon.png",
"description": "description",
"capabilities": None,
"model_config": {},
},
"params": {},
},
)
assert response.status_code == 200
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 1
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/my-model"))
assert response.status_code == 200
data = response.json()
assert data["id"] == "my-model"
assert data["name"] == "Hello World"
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/delete?id=my-model")
)
assert response.status_code == 200
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestPrompts(AbstractPostgresTest):
BASE_PATH = "/api/v1/prompts"
def test_prompts(self):
# Get all prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
# Create a two new prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"command": "/my-command",
"title": "Hello World",
"content": "description",
},
)
assert response.status_code == 200
with mock_webui_user(id="3"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"command": "/my-command2",
"title": "Hello World 2",
"content": "description 2",
},
)
assert response.status_code == 200
# Get all prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 2
# Get prompt by command
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/command/my-command"))
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command"
assert data["title"] == "Hello World"
assert data["content"] == "description"
assert data["user_id"] == "2"
# Update prompt
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/command/my-command2/update"),
json={
"command": "irrelevant for request",
"title": "Hello World Updated",
"content": "description Updated",
},
)
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command2"
assert data["title"] == "Hello World Updated"
assert data["content"] == "description Updated"
assert data["user_id"] == "3"
# Delete prompt
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/command/my-command/delete")
)
assert response.status_code == 200
# Get all prompts
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 1
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
def _get_user_by_id(data, param):
return next((item for item in data if item["id"] == param), None)
def _assert_user(data, id, **kwargs):
user = _get_user_by_id(data, id)
assert user is not None
comparison_data = {
"name": f"user {id}",
"email": f"user{id}@openwebui.com",
"profile_image_url": f"/user{id}.png",
"role": "user",
**kwargs,
}
for key, value in comparison_data.items():
assert user[key] == value
class TestUsers(AbstractPostgresTest):
BASE_PATH = "/api/v1/users"
def setup_class(cls):
super().setup_class()
from apps.webui.models.users import Users
cls.users = Users
def setup_method(self):
super().setup_method()
self.users.insert_new_user(
self.db_session,
id="1",
name="user 1",
email="user1@openwebui.com",
profile_image_url="/user1.png",
role="user",
)
self.users.insert_new_user(
self.db_session,
id="2",
name="user 2",
email="user2@openwebui.com",
profile_image_url="/user2.png",
role="user",
)
def test_users(self):
# Get all users
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert len(response.json()) == 2
data = response.json()
_assert_user(data, "1")
_assert_user(data, "2")
# update role
with mock_webui_user(id="3"):
response = self.fast_api_client.post(
self.create_url("/update/role"), json={"id": "2", "role": "admin"}
)
assert response.status_code == 200
_assert_user([response.json()], "2", role="admin")
# Get all users
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert len(response.json()) == 2
data = response.json()
_assert_user(data, "1")
_assert_user(data, "2", role="admin")
# Get (empty) user settings
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/user/settings"))
assert response.status_code == 200
assert response.json() is None
# Update user settings
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/user/settings/update"),
json={
"ui": {"attr1": "value1", "attr2": "value2"},
"model_config": {"attr3": "value3", "attr4": "value4"},
},
)
assert response.status_code == 200
# Get user settings
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/user/settings"))
assert response.status_code == 200
assert response.json() == {
"ui": {"attr1": "value1", "attr2": "value2"},
"model_config": {"attr3": "value3", "attr4": "value4"},
}
# Get (empty) user info
with mock_webui_user(id="1"):
response = self.fast_api_client.get(self.create_url("/user/info"))
assert response.status_code == 200
assert response.json() is None
# Update user info
with mock_webui_user(id="1"):
response = self.fast_api_client.post(
self.create_url("/user/info/update"),
json={"attr1": "value1", "attr2": "value2"},
)
assert response.status_code == 200
# Get user info
with mock_webui_user(id="1"):
response = self.fast_api_client.get(self.create_url("/user/info"))
assert response.status_code == 200
assert response.json() == {"attr1": "value1", "attr2": "value2"}
# Get user by id
with mock_webui_user(id="1"):
response = self.fast_api_client.get(self.create_url("/2"))
assert response.status_code == 200
assert response.json() == {"name": "user 2", "profile_image_url": "/user2.png"}
# Update user by id
with mock_webui_user(id="1"):
response = self.fast_api_client.post(
self.create_url("/2/update"),
json={
"name": "user 2 updated",
"email": "user2-updated@openwebui.com",
"profile_image_url": "/user2-updated.png",
},
)
assert response.status_code == 200
# Get all users
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert len(response.json()) == 2
data = response.json()
_assert_user(data, "1")
_assert_user(
data,
"2",
role="admin",
name="user 2 updated",
email="user2-updated@openwebui.com",
profile_image_url="/user2-updated.png",
)
# Delete user by id
with mock_webui_user(id="1"):
response = self.fast_api_client.delete(self.create_url("/2"))
assert response.status_code == 200
# Get all users
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert len(response.json()) == 1
data = response.json()
_assert_user(data, "1")
import logging
import os
import time
import docker
import pytest
from docker import DockerClient
from pytest_docker.plugin import get_docker_ip
from fastapi.testclient import TestClient
from sqlalchemy import text, create_engine
log = logging.getLogger(__name__)
def get_fast_api_client():
from main import app
with TestClient(app) as c:
return c
class AbstractIntegrationTest:
BASE_PATH = None
def create_url(self, path):
if self.BASE_PATH is None:
raise Exception("BASE_PATH is not set")
parts = self.BASE_PATH.split("/")
parts = [part.strip() for part in parts if part.strip() != ""]
path_parts = path.split("/")
path_parts = [part.strip() for part in path_parts if part.strip() != ""]
return "/".join(parts + path_parts)
@classmethod
def setup_class(cls):
pass
def setup_method(self):
pass
@classmethod
def teardown_class(cls):
pass
def teardown_method(self):
pass
class AbstractPostgresTest(AbstractIntegrationTest):
DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
docker_client: DockerClient
def get_db(self):
from apps.webui.internal.db import SessionLocal
return SessionLocal()
@classmethod
def _create_db_url(cls, env_vars_postgres: dict) -> str:
host = get_docker_ip()
user = env_vars_postgres["POSTGRES_USER"]
pw = env_vars_postgres["POSTGRES_PASSWORD"]
port = 8081
db = env_vars_postgres["POSTGRES_DB"]
return f"postgresql://{user}:{pw}@{host}:{port}/{db}"
@classmethod
def setup_class(cls):
super().setup_class()
try:
env_vars_postgres = {
"POSTGRES_USER": "user",
"POSTGRES_PASSWORD": "example",
"POSTGRES_DB": "openwebui",
}
cls.docker_client = docker.from_env()
cls.docker_client.containers.run(
"postgres:16.2",
detach=True,
environment=env_vars_postgres,
name=cls.DOCKER_CONTAINER_NAME,
ports={5432: ("0.0.0.0", 8081)},
command="postgres -c log_statement=all",
)
time.sleep(0.5)
database_url = cls._create_db_url(env_vars_postgres)
os.environ["DATABASE_URL"] = database_url
retries = 10
db = None
while retries > 0:
try:
from config import BACKEND_DIR
db = create_engine(database_url, pool_pre_ping=True)
db = db.connect()
log.info("postgres is ready!")
break
except Exception as e:
log.warning(e)
time.sleep(3)
retries -= 1
if db:
# import must be after setting env!
cls.fast_api_client = get_fast_api_client()
db.close()
else:
raise Exception("Could not connect to Postgres")
except Exception as ex:
log.error(ex)
cls.teardown_class()
pytest.fail(f"Could not setup test environment: {ex}")
def _check_db_connection(self):
retries = 10
while retries > 0:
try:
self.db_session.execute(text("SELECT 1"))
self.db_session.commit()
break
except Exception as e:
self.db_session.rollback()
log.warning(e)
time.sleep(3)
retries -= 1
def setup_method(self):
super().setup_method()
self.db_session = self.get_db()
self._check_db_connection()
@classmethod
def teardown_class(cls) -> None:
super().teardown_class()
cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
def teardown_method(self):
# rollback everything not yet committed
self.db_session.commit()
# truncate all tables
tables = [
"auth",
"chat",
"chatidtag",
"document",
"memory",
"model",
"prompt",
"tag",
'"user"',
]
for table in tables:
self.db_session.execute(text(f"TRUNCATE TABLE {table}"))
self.db_session.commit()
from contextlib import contextmanager
from fastapi import FastAPI
@contextmanager
def mock_webui_user(**kwargs):
from apps.webui.main import app
with mock_user(app, **kwargs):
yield
@contextmanager
def mock_user(app: FastAPI, **kwargs):
from utils.utils import (
get_current_user,
get_verified_user,
get_admin_user,
get_current_user_by_api_key,
)
from apps.webui.models.users import User
def create_user():
user_parameters = {
"id": "1",
"name": "John Doe",
"email": "john.doe@openwebui.com",
"role": "user",
"profile_image_url": "/user.png",
"last_active_at": 1627351200,
"updated_at": 1627351200,
"created_at": 162735120,
**kwargs,
}
return User(**user_parameters)
app.dependency_overrides = {
get_current_user: create_user,
get_verified_user: create_user,
get_admin_user: create_user,
get_current_user_by_api_key: create_user,
}
yield
app.dependency_overrides = {}
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends, Request from fastapi import HTTPException, status, Depends, Request
from sqlalchemy.orm import Session
from apps.webui.internal.db import get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
...@@ -77,6 +79,7 @@ def get_http_authorization_cred(auth_header: str): ...@@ -77,6 +79,7 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user( def get_current_user(
request: Request, request: Request,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
db=Depends(get_db),
): ):
token = None token = None
...@@ -91,19 +94,19 @@ def get_current_user( ...@@ -91,19 +94,19 @@ def get_current_user(
# auth by api key # auth by api key
if token.startswith("sk-"): if token.startswith("sk-"):
return get_current_user_by_api_key(token) return get_current_user_by_api_key(db, token)
# auth by jwt token # auth by jwt token
data = decode_token(token) data = decode_token(token)
if data != None and "id" in data: if data != None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(db, data["id"])
if user is None: if user is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
else: else:
Users.update_user_last_active_by_id(user.id) Users.update_user_last_active_by_id(db, user.id)
return user return user
else: else:
raise HTTPException( raise HTTPException(
...@@ -112,8 +115,8 @@ def get_current_user( ...@@ -112,8 +115,8 @@ def get_current_user(
) )
def get_current_user_by_api_key(api_key: str): def get_current_user_by_api_key(db: Session, api_key: str):
user = Users.get_user_by_api_key(api_key) user = Users.get_user_by_api_key(db, api_key)
if user is None: if user is None:
raise HTTPException( raise HTTPException(
...@@ -121,7 +124,7 @@ def get_current_user_by_api_key(api_key: str): ...@@ -121,7 +124,7 @@ def get_current_user_by_api_key(api_key: str):
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
else: else:
Users.update_user_last_active_by_id(user.id) Users.update_user_last_active_by_id(db, user.id)
return user return user
......
...@@ -63,10 +63,7 @@ export const getModelInfos = async (token: string = '') => { ...@@ -63,10 +63,7 @@ export const getModelInfos = async (token: string = '') => {
export const getModelById = async (token: string, id: string) => { export const getModelById = async (token: string, id: string) => {
let error = null; let error = null;
const searchParams = new URLSearchParams(); const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}`, {
searchParams.append('id', id);
const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment