Unverified Commit 8f6f7668 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3595 from open-webui/dev-migration

feat: db migration
parents 97a84918 4e751501
...@@ -67,6 +67,28 @@ jobs: ...@@ -67,6 +67,28 @@ jobs:
path: compose-logs.txt path: compose-logs.txt
if-no-files-found: ignore if-no-files-found: ignore
pytest:
name: Run backend tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r backend/requirements.txt
- name: pytest run
run: |
ls -al
cd backend
PYTHONPATH=. pytest . -o log_cli=true -o log_cli_level=INFO
migration_test: migration_test:
name: Run Migration Tests name: Run Migration Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
...@@ -171,7 +193,7 @@ jobs: ...@@ -171,7 +193,7 @@ jobs:
fi fi
# Check that service will reconnect to postgres when connection will be closed # Check that service will reconnect to postgres when connection will be closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health) status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
if [[ "$status_code" -ne 200 ]] ; then if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check" echo "Server has failed before postgres reconnect check"
exit 1 exit 1
...@@ -183,7 +205,7 @@ jobs: ...@@ -183,7 +205,7 @@ jobs:
cur = conn.cursor(); \ cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')" cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health) status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
if [[ "$status_code" -ne 200 ]] ; then if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code" echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1 exit 1
......
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# sqlalchemy.url = REPLACE_WITH_DATABASE_URL
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
import os import os
import logging import logging
import json import json
from contextlib import contextmanager
from peewee import *
from peewee_migrate import Router from peewee_migrate import Router
from apps.webui.internal.wrappers import register_connection from apps.webui.internal.wrappers import register_connection
from typing import Optional, Any
from typing_extensions import Self
from sqlalchemy import create_engine, types, Dialect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.sql.type_api import _T
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"]) log.setLevel(SRC_LOG_LEVELS["DB"])
class JSONField(TextField): class JSONField(types.TypeDecorator):
impl = types.Text
cache_ok = True
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
return json.dumps(value)
def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
if value is not None:
return json.loads(value)
def copy(self, **kw: Any) -> Self:
return JSONField(self.impl.length)
def db_value(self, value): def db_value(self, value):
return json.dumps(value) return json.dumps(value)
...@@ -30,25 +51,57 @@ else: ...@@ -30,25 +51,57 @@ else:
pass pass
# The `register_connection` function encapsulates the logic for setting up # Workaround to handle the peewee migration
# the database connection based on the connection string, while `connect` # This is required to ensure the peewee migration is handled before the alembic migration
# is a Peewee-specific method to manage the connection state and avoid errors def handle_peewee_migration():
# when a connection is already open. try:
try: db = register_connection(DATABASE_URL)
DB = register_connection(DATABASE_URL) migrate_dir = BACKEND_DIR / "apps" / "webui" / "internal" / "migrations"
log.info(f"Connected to a {DB.__class__.__name__} database.") router = Router(db, logger=log, migrate_dir=migrate_dir)
except Exception as e: router.run()
db.close()
# check if db connection has been closed
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}") log.error(f"Failed to initialize the database connection: {e}")
raise raise
router = Router( finally:
DB, # Properly closing the database connection
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations", if db and not db.is_closed():
logger=log, db.close()
# Assert if db connection has been closed
assert db.is_closed(), "Database connection is still open."
handle_peewee_migration()
SQLALCHEMY_DATABASE_URL = DATABASE_URL
if "sqlite" in SQLALCHEMY_DATABASE_URL:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
) )
router.run() Base = declarative_base()
try: Session = scoped_session(SessionLocal)
DB.connect(reuse_if_open=True)
except OperationalError as e:
log.info(f"Failed to connect to database again due to: {e}") # Dependency
pass def get_session():
db = SessionLocal()
try:
yield db
finally:
db.close()
get_db = contextmanager(get_session)
"""Peewee migrations -- 017_add_user_oauth_sub.py. """Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name):: Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name > Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name > Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL > migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args > migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator) > migrator.create_model(Model) # Create a model (could be used as decorator)
...@@ -21,7 +18,6 @@ Some examples (model - class or model name):: ...@@ -21,7 +18,6 @@ Some examples (model - class or model name)::
> migrator.drop_index(model, *col_names) > migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names) > migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints) > migrator.drop_constraints(model, *constraints)
""" """
from contextlib import suppress from contextlib import suppress
......
# Database Migrations
This directory contains all the database migrations for the web app.
Migrations are done using the [`peewee-migrate`](https://github.com/klen/peewee_migrate) library.
Migrations are automatically ran at app startup.
## Creating a migration
Have you made a change to the schema of an existing model?
You will need to create a migration file to ensure that existing databases are updated for backwards compatibility.
1. Have a database file (`webui.db`) that has the old schema prior to any of your changes.
2. Make your changes to the models.
3. From the `backend` directory, run the following command:
```bash
pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME}
```
- `$SQLITE_DB` should be the path to the database file.
- `$MIGRATION_NAME` should be a descriptive name for the migration.
4. The migration file will be created in the `apps/web/internal/migrations` directory.
...@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute ...@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from sqlalchemy.orm import Session
from apps.webui.routers import ( from apps.webui.routers import (
auths, auths,
users, users,
......
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Union, Optional from typing import Optional
import time
import uuid import uuid
import logging import logging
from peewee import * from sqlalchemy import String, Column, Boolean, Text
from apps.webui.models.users import UserModel, Users from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password from utils.utils import verify_password
from apps.webui.internal.db import DB from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -20,14 +19,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -20,14 +19,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class Auth(Model): class Auth(Base):
id = CharField(unique=True) __tablename__ = "auth"
email = CharField()
password = TextField()
active = BooleanField()
class Meta: id = Column(String, primary_key=True)
database = DB email = Column(String)
password = Column(Text)
active = Column(Boolean)
class AuthModel(BaseModel): class AuthModel(BaseModel):
...@@ -94,9 +92,6 @@ class AddUserForm(SignupForm): ...@@ -94,9 +92,6 @@ class AddUserForm(SignupForm):
class AuthsTable: class AuthsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Auth])
def insert_new_auth( def insert_new_auth(
self, self,
...@@ -107,6 +102,8 @@ class AuthsTable: ...@@ -107,6 +102,8 @@ class AuthsTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db:
log.info("insert_new_auth") log.info("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
...@@ -114,12 +111,16 @@ class AuthsTable: ...@@ -114,12 +111,16 @@ class AuthsTable:
auth = AuthModel( auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True} **{"id": id, "email": email, "password": password, "active": True}
) )
result = Auth.create(**auth.model_dump()) result = Auth(**auth.model_dump())
db.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub id, name, email, profile_image_url, role, oauth_sub
) )
db.commit()
db.refresh(result)
if result and user: if result and user:
return user return user
else: else:
...@@ -128,7 +129,9 @@ class AuthsTable: ...@@ -128,7 +129,9 @@ class AuthsTable:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
try: try:
auth = Auth.get(Auth.email == email, Auth.active == True) with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth: if auth:
if verify_password(password, auth.password): if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id) user = Users.get_user_by_id(auth.id)
...@@ -155,7 +158,8 @@ class AuthsTable: ...@@ -155,7 +158,8 @@ class AuthsTable:
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}") log.info(f"authenticate_user_by_trusted_header: {email}")
try: try:
auth = Auth.get(Auth.email == email, Auth.active == True) with get_db() as db:
auth = db.query(Auth).filter(email=email, active=True).first()
if auth: if auth:
user = Users.get_user_by_id(auth.id) user = Users.get_user_by_id(auth.id)
return user return user
...@@ -164,31 +168,33 @@ class AuthsTable: ...@@ -164,31 +168,33 @@ class AuthsTable:
def update_user_password_by_id(self, id: str, new_password: str) -> bool: def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try: try:
query = Auth.update(password=new_password).where(Auth.id == id) with get_db() as db:
result = query.execute()
result = (
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
def update_email_by_id(self, id: str, email: str) -> bool: def update_email_by_id(self, id: str, email: str) -> bool:
try: try:
query = Auth.update(email=email).where(Auth.id == id) with get_db() as db:
result = query.execute()
result = db.query(Auth).filter_by(id=id).update({"email": email})
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
def delete_auth_by_id(self, id: str) -> bool: def delete_auth_by_id(self, id: str) -> bool:
try: try:
with get_db() as db:
# Delete User # Delete User
result = Users.delete_user_by_id(id) result = Users.delete_user_by_id(id)
if result: if result:
# Delete Auth db.query(Auth).filter_by(id=id).delete()
query = Auth.delete().where(Auth.id == id)
query.execute() # Remove the rows, return number of rows removed.
return True return True
else: else:
...@@ -197,4 +203,4 @@ class AuthsTable: ...@@ -197,4 +203,4 @@ class AuthsTable:
return False return False
Auths = AuthsTable(DB) Auths = AuthsTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json import json
import uuid import uuid
import time import time
from apps.webui.internal.db import DB from sqlalchemy import Column, String, BigInteger, Boolean, Text
from apps.webui.internal.db import Base, get_db
#################### ####################
# Chat DB Schema # Chat DB Schema
#################### ####################
class Chat(Model): class Chat(Base):
id = CharField(unique=True) __tablename__ = "chat"
user_id = CharField()
title = TextField()
chat = TextField() # Save Chat JSON as Text
created_at = BigIntegerField() id = Column(String, primary_key=True)
updated_at = BigIntegerField() user_id = Column(String)
title = Column(Text)
chat = Column(Text) # Save Chat JSON as Text
share_id = CharField(null=True, unique=True) created_at = Column(BigInteger)
archived = BooleanField(default=False) updated_at = Column(BigInteger)
class Meta: share_id = Column(Text, unique=True, nullable=True)
database = DB archived = Column(Boolean, default=False)
class ChatModel(BaseModel): class ChatModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str id: str
user_id: str user_id: str
title: str title: str
...@@ -75,18 +77,19 @@ class ChatTitleIdResponse(BaseModel): ...@@ -75,18 +77,19 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def __init__(self, db):
self.db = db
db.create_tables([Chat])
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"title": ( "title": (
form_data.chat["title"] if "title" in form_data.chat else "New Chat" form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
), ),
"chat": json.dumps(form_data.chat), "chat": json.dumps(form_data.chat),
"created_at": int(time.time()), "created_at": int(time.time()),
...@@ -94,26 +97,32 @@ class ChatTable: ...@@ -94,26 +97,32 @@ class ChatTable:
} }
) )
result = Chat.create(**chat.model_dump()) result = Chat(**chat.model_dump())
return chat if result else None db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try: try:
query = Chat.update( with get_db() as db:
chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat", chat_obj = db.get(Chat, id)
updated_at=int(time.time()), chat_obj.chat = json.dumps(chat)
).where(Chat.id == id) chat_obj.title = chat["title"] if "title" in chat else "New Chat"
query.execute() chat_obj.updated_at = int(time.time())
db.commit()
chat = Chat.get(Chat.id == id) db.refresh(chat_obj)
return ChatModel(**model_to_dict(chat))
except: return ChatModel.model_validate(chat_obj)
except Exception as e:
return None return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share # Get the existing chat to share
chat = Chat.get(Chat.id == chat_id) chat = db.get(Chat, chat_id)
# Check if the chat is already shared # Check if the chat is already shared
if chat.share_id: if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared") return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
...@@ -128,37 +137,40 @@ class ChatTable: ...@@ -128,37 +137,40 @@ class ChatTable:
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
shared_result = Chat.create(**shared_chat.model_dump()) shared_result = Chat(**shared_chat.model_dump())
db.add(shared_result)
db.commit()
db.refresh(shared_result)
# Update the original chat with the share_id # Update the original chat with the share_id
result = ( result = (
Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute() db.query(Chat)
.filter_by(id=chat_id)
.update({"share_id": shared_chat.id})
) )
return shared_chat if (shared_result and result) else None return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db:
print("update_shared_chat_by_id") print("update_shared_chat_by_id")
chat = Chat.get(Chat.id == chat_id) chat = db.get(Chat, chat_id)
print(chat) print(chat)
chat.title = chat.title
chat.chat = chat.chat
db.commit()
db.refresh(chat)
query = Chat.update( return self.get_chat_by_id(chat.share_id)
title=chat.title,
chat=chat.chat,
).where(Chat.id == chat.share_id)
query.execute()
chat = Chat.get(Chat.id == chat.share_id)
return ChatModel(**model_to_dict(chat))
except: except:
return None return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try: try:
query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}") with get_db() as db:
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
return True return True
except: except:
return False return False
...@@ -167,40 +179,33 @@ class ChatTable: ...@@ -167,40 +179,33 @@ class ChatTable:
self, id: str, share_id: Optional[str] self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
query = Chat.update( with get_db() as db:
share_id=share_id,
).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id) chat = db.get(Chat, id)
return ChatModel(**model_to_dict(chat)) chat.share_id = share_id
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except: except:
return None return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = self.get_chat_by_id(id) with get_db() as db:
query = Chat.update(
archived=(not chat.archived),
).where(Chat.id == id)
query.execute() chat = db.get(Chat, id)
chat.archived = not chat.archived
chat = Chat.get(Chat.id == id) db.commit()
return ChatModel(**model_to_dict(chat)) db.refresh(chat)
return ChatModel.model_validate(chat)
except: except:
return None return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool: def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try: try:
chats = self.get_chats_by_user_id(user_id) with get_db() as db:
for chat in chats:
query = Chat.update(
archived=True,
).where(Chat.id == chat.id)
query.execute()
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
return True return True
except: except:
return False return False
...@@ -208,15 +213,16 @@ class ChatTable: ...@@ -208,15 +213,16 @@ class ChatTable:
def get_archived_chat_list_by_user_id( def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
return [ with get_db() as db:
ChatModel(**model_to_dict(chat))
for chat in Chat.select() all_chats = (
.where(Chat.archived == True) db.query(Chat)
.where(Chat.user_id == user_id) .filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
# .limit(limit) # .limit(limit).offset(skip)
# .offset(skip) .all()
] )
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id( def get_chat_list_by_user_id(
self, self,
...@@ -225,92 +231,98 @@ class ChatTable: ...@@ -225,92 +231,98 @@ class ChatTable:
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
) -> List[ChatModel]: ) -> List[ChatModel]:
if include_archived: with get_db() as db:
return [ query = db.query(Chat).filter_by(user_id=user_id)
ChatModel(**model_to_dict(chat)) if not include_archived:
for chat in Chat.select() query = query.filter_by(archived=False)
.where(Chat.user_id == user_id) all_chats = (
.order_by(Chat.updated_at.desc()) query.order_by(Chat.updated_at.desc())
# .limit(limit) # .limit(limit).offset(skip)
# .offset(skip) .all()
] )
else: return [ChatModel.model_validate(chat) for chat in all_chats]
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
def get_chat_list_by_chat_ids( def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50 self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat)) with get_db() as db:
for chat in Chat.select()
.where(Chat.archived == False) all_chats = (
.where(Chat.id.in_(chat_ids)) db.query(Chat)
.filter(Chat.id.in_(chat_ids))
.filter_by(archived=False)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
] .all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = Chat.get(Chat.id == id) with get_db() as db:
return ChatModel(**model_to_dict(chat))
chat = db.get(Chat, id)
return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = Chat.get(Chat.share_id == id) with get_db() as db:
chat = db.query(Chat).filter_by(share_id=id).first()
if chat: if chat:
chat = Chat.get(Chat.id == id) return self.get_chat_by_id(id)
return ChatModel(**model_to_dict(chat))
else: else:
return None return None
except: except Exception as e:
return None return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try: try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id) with get_db() as db:
return ChatModel(**model_to_dict(chat))
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
return [ with get_db() as db:
ChatModel(**model_to_dict(chat))
for chat in Chat.select().order_by(Chat.updated_at.desc()) all_chats = (
db.query(Chat)
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] .order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [ with get_db() as db:
ChatModel(**model_to_dict(chat))
for chat in Chat.select() all_chats = (
.where(Chat.user_id == user_id) db.query(Chat)
.filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip) )
] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [ with get_db() as db:
ChatModel(**model_to_dict(chat))
for chat in Chat.select() all_chats = (
.where(Chat.archived == True) db.query(Chat)
.where(Chat.user_id == user_id) .filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
] )
return [ChatModel.model_validate(chat) for chat in all_chats]
def delete_chat_by_id(self, id: str) -> bool: def delete_chat_by_id(self, id: str) -> bool:
try: try:
query = Chat.delete().where((Chat.id == id)) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter_by(id=id).delete()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
...@@ -318,8 +330,9 @@ class ChatTable: ...@@ -318,8 +330,9 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
...@@ -328,28 +341,28 @@ class ChatTable: ...@@ -328,28 +341,28 @@ class ChatTable:
def delete_chats_by_user_id(self, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str) -> bool:
try: try:
self.delete_shared_chats_by_user_id(user_id) with get_db() as db:
query = Chat.delete().where(Chat.user_id == user_id) self.delete_shared_chats_by_user_id(user_id)
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter_by(user_id=user_id).delete()
return True return True
except: except:
return False return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
shared_chat_ids = [
f"shared-{chat.id}"
for chat in Chat.select().where(Chat.user_id == user_id)
]
query = Chat.delete().where(Chat.user_id << shared_chat_ids) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed.
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
return True return True
except: except:
return False return False
Chats = ChatTable(DB) Chats = ChatTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import * from typing import List, Optional
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time import time
import logging import logging
from utils.utils import decode_token from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB from apps.webui.internal.db import Base, get_db
import json import json
...@@ -22,20 +19,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -22,20 +19,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class Document(Model): class Document(Base):
collection_name = CharField(unique=True) __tablename__ = "document"
name = CharField(unique=True)
title = TextField()
filename = TextField()
content = TextField(null=True)
user_id = CharField()
timestamp = BigIntegerField()
class Meta: collection_name = Column(String, primary_key=True)
database = DB name = Column(String, unique=True)
title = Column(Text)
filename = Column(Text)
content = Column(Text, nullable=True)
user_id = Column(String)
timestamp = Column(BigInteger)
class DocumentModel(BaseModel): class DocumentModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
collection_name: str collection_name: str
name: str name: str
title: str title: str
...@@ -72,13 +70,12 @@ class DocumentForm(DocumentUpdateForm): ...@@ -72,13 +70,12 @@ class DocumentForm(DocumentUpdateForm):
class DocumentsTable: class DocumentsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Document])
def insert_new_doc( def insert_new_doc(
self, user_id: str, form_data: DocumentForm self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
with get_db() as db:
document = DocumentModel( document = DocumentModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
...@@ -88,9 +85,12 @@ class DocumentsTable: ...@@ -88,9 +85,12 @@ class DocumentsTable:
) )
try: try:
result = Document.create(**document.model_dump()) result = Document(**document.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return document return DocumentModel.model_validate(result)
else: else:
return None return None
except: except:
...@@ -98,31 +98,35 @@ class DocumentsTable: ...@@ -98,31 +98,35 @@ class DocumentsTable:
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try: try:
document = Document.get(Document.name == name) with get_db() as db:
return DocumentModel(**model_to_dict(document))
document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None
except: except:
return None return None
def get_docs(self) -> List[DocumentModel]: def get_docs(self) -> List[DocumentModel]:
with get_db() as db:
return [ return [
DocumentModel(**model_to_dict(doc)) DocumentModel.model_validate(doc) for doc in db.query(Document).all()
for doc in Document.select()
# .limit(limit).offset(skip)
] ]
def update_doc_by_name( def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
try: try:
query = Document.update( with get_db() as db:
title=form_data.title,
name=form_data.name, db.query(Document).filter_by(name=name).update(
timestamp=int(time.time()), {
).where(Document.name == name) "title": form_data.title,
query.execute() "name": form_data.name,
"timestamp": int(time.time()),
doc = Document.get(Document.name == form_data.name) }
return DocumentModel(**model_to_dict(doc)) )
db.commit()
return self.get_doc_by_name(form_data.name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
...@@ -135,26 +139,28 @@ class DocumentsTable: ...@@ -135,26 +139,28 @@ class DocumentsTable:
doc_content = json.loads(doc.content if doc.content else "{}") doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated} doc_content = {**doc_content, **updated}
query = Document.update( with get_db() as db:
content=json.dumps(doc_content),
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
doc = Document.get(Document.name == name) db.query(Document).filter_by(name=name).update(
return DocumentModel(**model_to_dict(doc)) {
"content": json.dumps(doc_content),
"timestamp": int(time.time()),
}
)
db.commit()
return self.get_doc_by_name(name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def delete_doc_by_name(self, name: str) -> bool: def delete_doc_by_name(self, name: str) -> bool:
try: try:
query = Document.delete().where((Document.name == name)) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed.
db.query(Document).filter_by(name=name).delete()
return True return True
except: except:
return False return False
Documents = DocumentsTable(DB) Documents = DocumentsTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
import logging import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import Column, String, BigInteger, Text
from apps.webui.internal.db import JSONField, Base, get_db
import json import json
...@@ -18,15 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -18,15 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class File(Model): class File(Base):
id = CharField(unique=True) __tablename__ = "file"
user_id = CharField()
filename = TextField()
meta = JSONField()
created_at = BigIntegerField()
class Meta: id = Column(String, primary_key=True)
database = DB user_id = Column(String)
filename = Column(Text)
meta = Column(JSONField)
created_at = Column(BigInteger)
class FileModel(BaseModel): class FileModel(BaseModel):
...@@ -36,6 +36,8 @@ class FileModel(BaseModel): ...@@ -36,6 +36,8 @@ class FileModel(BaseModel):
meta: dict meta: dict
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -57,11 +59,10 @@ class FileForm(BaseModel): ...@@ -57,11 +59,10 @@ class FileForm(BaseModel):
class FilesTable: class FilesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([File])
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db:
file = FileModel( file = FileModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
...@@ -71,9 +72,12 @@ class FilesTable: ...@@ -71,9 +72,12 @@ class FilesTable:
) )
try: try:
result = File.create(**file.model_dump()) result = File(**file.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return file return FileModel.model_validate(result)
else: else:
return None return None
except Exception as e: except Exception as e:
...@@ -81,32 +85,38 @@ class FilesTable: ...@@ -81,32 +85,38 @@ class FilesTable:
return None return None
def get_file_by_id(self, id: str) -> Optional[FileModel]: def get_file_by_id(self, id: str) -> Optional[FileModel]:
with get_db() as db:
try: try:
file = File.get(File.id == id) file = db.get(File, id)
return FileModel(**model_to_dict(file)) return FileModel.model_validate(file)
except: except:
return None return None
def get_files(self) -> List[FileModel]: def get_files(self) -> List[FileModel]:
return [FileModel(**model_to_dict(file)) for file in File.select()] with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def delete_file_by_id(self, id: str) -> bool: def delete_file_by_id(self, id: str) -> bool:
try:
query = File.delete().where((File.id == id))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
try:
db.query(File).filter_by(id=id).delete()
return True return True
except: except:
return False return False
def delete_all_files(self) -> bool: def delete_all_files(self) -> bool:
try:
query = File.delete()
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
try:
db.query(File).delete()
return True return True
except: except:
return False return False
Files = FilesTable(DB) Files = FilesTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
import logging import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import Column, String, Text, BigInteger, Boolean
from apps.webui.internal.db import JSONField, Base, get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
...@@ -21,21 +22,20 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -21,21 +22,20 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class Function(Model): class Function(Base):
id = CharField(unique=True) __tablename__ = "function"
user_id = CharField()
name = TextField()
type = TextField()
content = TextField()
meta = JSONField()
valves = JSONField()
is_active = BooleanField(default=False)
is_global = BooleanField(default=False)
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta: id = Column(String, primary_key=True)
database = DB user_id = Column(String)
name = Column(Text)
type = Column(Text)
content = Column(Text)
meta = Column(JSONField)
valves = Column(JSONField)
is_active = Column(Boolean)
is_global = Column(Boolean)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class FunctionMeta(BaseModel): class FunctionMeta(BaseModel):
...@@ -55,6 +55,8 @@ class FunctionModel(BaseModel): ...@@ -55,6 +55,8 @@ class FunctionModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -85,13 +87,11 @@ class FunctionValves(BaseModel): ...@@ -85,13 +87,11 @@ class FunctionValves(BaseModel):
class FunctionsTable: class FunctionsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Function])
def insert_new_function( def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
function = FunctionModel( function = FunctionModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
...@@ -103,9 +103,13 @@ class FunctionsTable: ...@@ -103,9 +103,13 @@ class FunctionsTable:
) )
try: try:
result = Function.create(**function.model_dump()) with get_db() as db:
result = Function(**function.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return function return FunctionModel.model_validate(result)
else: else:
return None return None
except Exception as e: except Exception as e:
...@@ -114,52 +118,60 @@ class FunctionsTable: ...@@ -114,52 +118,60 @@ class FunctionsTable:
def get_function_by_id(self, id: str) -> Optional[FunctionModel]: def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try: try:
function = Function.get(Function.id == id) with get_db() as db:
return FunctionModel(**model_to_dict(function))
function = db.get(Function, id)
return FunctionModel.model_validate(function)
except: except:
return None return None
def get_functions(self, active_only=False) -> List[FunctionModel]: def get_functions(self, active_only=False) -> List[FunctionModel]:
with get_db() as db:
if active_only: if active_only:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel.model_validate(function)
for function in Function.select().where(Function.is_active == True) for function in db.query(Function).filter_by(is_active=True).all()
] ]
else: else:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel.model_validate(function)
for function in Function.select() for function in db.query(Function).all()
] ]
def get_functions_by_type( def get_functions_by_type(
self, type: str, active_only=False self, type: str, active_only=False
) -> List[FunctionModel]: ) -> List[FunctionModel]:
with get_db() as db:
if active_only: if active_only:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel.model_validate(function)
for function in Function.select().where( for function in db.query(Function)
Function.type == type, Function.is_active == True .filter_by(type=type, is_active=True)
) .all()
] ]
else: else:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel.model_validate(function)
for function in Function.select().where(Function.type == type) for function in db.query(Function).filter_by(type=type).all()
] ]
def get_global_filter_functions(self) -> List[FunctionModel]: def get_global_filter_functions(self) -> List[FunctionModel]:
with get_db() as db:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel.model_validate(function)
for function in Function.select().where( for function in db.query(Function)
Function.type == "filter", .filter_by(type="filter", is_active=True, is_global=True)
Function.is_active == True, .all()
Function.is_global == True,
)
] ]
def get_function_valves_by_id(self, id: str) -> Optional[dict]: def get_function_valves_by_id(self, id: str) -> Optional[dict]:
with get_db() as db:
try: try:
function = Function.get(Function.id == id) function = db.get(Function, id)
return function.valves if function.valves else {} return function.valves if function.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
...@@ -168,21 +180,22 @@ class FunctionsTable: ...@@ -168,21 +180,22 @@ class FunctionsTable:
def update_function_valves_by_id( def update_function_valves_by_id(
self, id: str, valves: dict self, id: str, valves: dict
) -> Optional[FunctionValves]: ) -> Optional[FunctionValves]:
with get_db() as db:
try: try:
query = Function.update( function = db.get(Function, id)
**{"valves": valves}, function.valves = valves
updated_at=int(time.time()), function.updated_at = int(time.time())
).where(Function.id == id) db.commit()
query.execute() db.refresh(function)
return self.get_function_by_id(id)
function = Function.get(Function.id == id)
return FunctionValves(**model_to_dict(function))
except: except:
return None return None
def get_user_valves_by_id_and_user_id( def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
...@@ -201,6 +214,7 @@ class FunctionsTable: ...@@ -201,6 +214,7 @@ class FunctionsTable:
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict self, id: str, user_id: str, valves: dict
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
...@@ -222,39 +236,43 @@ class FunctionsTable: ...@@ -222,39 +236,43 @@ class FunctionsTable:
return None return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
with get_db() as db:
try: try:
query = Function.update( db.query(Function).filter_by(id=id).update(
{
**updated, **updated,
updated_at=int(time.time()), "updated_at": int(time.time()),
).where(Function.id == id) }
query.execute() )
db.commit()
function = Function.get(Function.id == id) return self.get_function_by_id(id)
return FunctionModel(**model_to_dict(function))
except: except:
return None return None
def deactivate_all_functions(self) -> Optional[bool]: def deactivate_all_functions(self) -> Optional[bool]:
with get_db() as db:
try: try:
query = Function.update( db.query(Function).update(
**{"is_active": False}, {
updated_at=int(time.time()), "is_active": False,
"updated_at": int(time.time()),
}
) )
db.commit()
query.execute()
return True return True
except: except:
return None return None
def delete_function_by_id(self, id: str) -> bool: def delete_function_by_id(self, id: str) -> bool:
try: with get_db() as db:
query = Function.delete().where((Function.id == id))
query.execute() # Remove the rows, return number of rows removed.
try:
db.query(Function).filter_by(id=id).delete()
return True return True
except: except:
return False return False
Functions = FunctionsTable(DB) Functions = FunctionsTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
from apps.webui.internal.db import DB from sqlalchemy import Column, String, BigInteger, Text
from apps.webui.models.chats import Chats
from apps.webui.internal.db import Base, get_db
import time import time
import uuid import uuid
...@@ -14,15 +13,14 @@ import uuid ...@@ -14,15 +13,14 @@ import uuid
#################### ####################
class Memory(Model): class Memory(Base):
id = CharField(unique=True) __tablename__ = "memory"
user_id = CharField()
content = TextField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta: id = Column(String, primary_key=True)
database = DB user_id = Column(String)
content = Column(Text)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class MemoryModel(BaseModel): class MemoryModel(BaseModel):
...@@ -32,6 +30,8 @@ class MemoryModel(BaseModel): ...@@ -32,6 +30,8 @@ class MemoryModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -39,15 +39,14 @@ class MemoryModel(BaseModel): ...@@ -39,15 +39,14 @@ class MemoryModel(BaseModel):
class MemoriesTable: class MemoriesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Memory])
def insert_new_memory( def insert_new_memory(
self, self,
user_id: str, user_id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
memory = MemoryModel( memory = MemoryModel(
...@@ -59,9 +58,12 @@ class MemoriesTable: ...@@ -59,9 +58,12 @@ class MemoriesTable:
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
result = Memory.create(**memory.model_dump()) result = Memory(**memory.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return memory return MemoryModel.model_validate(result)
else: else:
return None return None
...@@ -70,63 +72,71 @@ class MemoriesTable: ...@@ -70,63 +72,71 @@ class MemoriesTable:
id: str, id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db:
try: try:
memory = Memory.get(Memory.id == id) db.query(Memory).filter_by(id=id).update(
memory.content = content {"content": content, "updated_at": int(time.time())}
memory.updated_at = int(time.time()) )
memory.save() db.commit()
return MemoryModel(**model_to_dict(memory)) return self.get_memory_by_id(id)
except: except:
return None return None
def get_memories(self) -> List[MemoryModel]: def get_memories(self) -> List[MemoryModel]:
with get_db() as db:
try: try:
memories = Memory.select() memories = db.query(Memory).all()
return [MemoryModel(**model_to_dict(memory)) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
with get_db() as db:
try: try:
memories = Memory.select().where(Memory.user_id == user_id) memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel(**model_to_dict(memory)) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memory_by_id(self, id) -> Optional[MemoryModel]: def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
with get_db() as db:
try: try:
memory = Memory.get(Memory.id == id) memory = db.get(Memory, id)
return MemoryModel(**model_to_dict(memory)) return MemoryModel.model_validate(memory)
except: except:
return None return None
def delete_memory_by_id(self, id: str) -> bool: def delete_memory_by_id(self, id: str) -> bool:
try: with get_db() as db:
query = Memory.delete().where(Memory.id == id)
query.execute() # Remove the rows, return number of rows removed.
try:
db.query(Memory).filter_by(id=id).delete()
return True return True
except: except:
return False return False
def delete_memories_by_user_id(self, user_id: str) -> bool: def delete_memories_by_user_id(self, user_id: str) -> bool:
try: with get_db() as db:
query = Memory.delete().where(Memory.user_id == user_id)
query.execute()
try:
db.query(Memory).filter_by(user_id=user_id).delete()
return True return True
except: except:
return False return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: with get_db() as db:
query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
query.execute()
try:
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
return True return True
except: except:
return False return False
Memories = MemoriesTable(DB) Memories = MemoriesTable()
...@@ -2,13 +2,10 @@ import json ...@@ -2,13 +2,10 @@ import json
import logging import logging
from typing import Optional from typing import Optional
import peewee as pw
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -46,38 +43,37 @@ class ModelMeta(BaseModel): ...@@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
pass pass
class Model(pw.Model): class Model(Base):
id = pw.TextField(unique=True) __tablename__ = "model"
id = Column(Text, primary_key=True)
""" """
The model's id as used in the API. If set to an existing model, it will override the model. The model's id as used in the API. If set to an existing model, it will override the model.
""" """
user_id = pw.TextField() user_id = Column(Text)
base_model_id = pw.TextField(null=True) base_model_id = Column(Text, nullable=True)
""" """
An optional pointer to the actual model that should be used when proxying requests. An optional pointer to the actual model that should be used when proxying requests.
""" """
name = pw.TextField() name = Column(Text)
""" """
The human-readable display name of the model. The human-readable display name of the model.
""" """
params = JSONField() params = Column(JSONField)
""" """
Holds a JSON encoded blob of parameters, see `ModelParams`. Holds a JSON encoded blob of parameters, see `ModelParams`.
""" """
meta = JSONField() meta = Column(JSONField)
""" """
Holds a JSON encoded blob of metadata, see `ModelMeta`. Holds a JSON encoded blob of metadata, see `ModelMeta`.
""" """
updated_at = BigIntegerField() updated_at = Column(BigInteger)
created_at = BigIntegerField() created_at = Column(BigInteger)
class Meta:
database = DB
class ModelModel(BaseModel): class ModelModel(BaseModel):
...@@ -92,6 +88,8 @@ class ModelModel(BaseModel): ...@@ -92,6 +88,8 @@ class ModelModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -115,12 +113,6 @@ class ModelForm(BaseModel): ...@@ -115,12 +113,6 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def __init__(
self,
db: pw.SqliteDatabase | pw.PostgresqlDatabase,
):
self.db = db
self.db.create_tables([Model])
def insert_new_model( def insert_new_model(
self, form_data: ModelForm, user_id: str self, form_data: ModelForm, user_id: str
...@@ -134,10 +126,16 @@ class ModelsTable: ...@@ -134,10 +126,16 @@ class ModelsTable:
} }
) )
try: try:
result = Model.create(**model.model_dump())
with get_db() as db:
result = Model(**model.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return model return ModelModel.model_validate(result)
else: else:
return None return None
except Exception as e: except Exception as e:
...@@ -145,23 +143,29 @@ class ModelsTable: ...@@ -145,23 +143,29 @@ class ModelsTable:
return None return None
def get_all_models(self) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()] with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
model = Model.get(Model.id == id) with get_db() as db:
return ModelModel(**model_to_dict(model))
model = db.get(Model, id)
return ModelModel.model_validate(model)
except: except:
return None return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try: try:
# update only the fields that are present in the model with get_db() as db:
query = Model.update(**model.model_dump()).where(Model.id == id)
query.execute()
model = Model.get(Model.id == id) # update only the fields that are present in the model
return ModelModel(**model_to_dict(model)) model = db.query(Model).get(id)
model.update(**model.model_dump())
db.commit()
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e: except Exception as e:
print(e) print(e)
...@@ -169,11 +173,12 @@ class ModelsTable: ...@@ -169,11 +173,12 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
query = Model.delete().where(Model.id == id) with get_db() as db:
query.execute()
db.query(Model).filter_by(id=id).delete()
return True return True
except: except:
return False return False
Models = ModelsTable(DB) Models = ModelsTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import * from typing import List, Optional
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time import time
from utils.utils import decode_token from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB from apps.webui.internal.db import Base, get_db
import json import json
...@@ -16,15 +13,14 @@ import json ...@@ -16,15 +13,14 @@ import json
#################### ####################
class Prompt(Model): class Prompt(Base):
command = CharField(unique=True) __tablename__ = "prompt"
user_id = CharField()
title = TextField()
content = TextField()
timestamp = BigIntegerField()
class Meta: command = Column(String, primary_key=True)
database = DB user_id = Column(String)
title = Column(Text)
content = Column(Text)
timestamp = Column(BigInteger)
class PromptModel(BaseModel): class PromptModel(BaseModel):
...@@ -34,6 +30,8 @@ class PromptModel(BaseModel): ...@@ -34,6 +30,8 @@ class PromptModel(BaseModel):
content: str content: str
timestamp: int # timestamp in epoch timestamp: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -48,10 +46,6 @@ class PromptForm(BaseModel): ...@@ -48,10 +46,6 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt( def insert_new_prompt(
self, user_id: str, form_data: PromptForm self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
...@@ -66,53 +60,58 @@ class PromptsTable: ...@@ -66,53 +60,58 @@ class PromptsTable:
) )
try: try:
result = Prompt.create(**prompt.model_dump()) with get_db() as db:
result = Prompt(**prompt.dict())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return prompt return PromptModel.model_validate(result)
else: else:
return None return None
except: except Exception as e:
return None return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try: try:
prompt = Prompt.get(Prompt.command == command) with get_db() as db:
return PromptModel(**model_to_dict(prompt))
prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt)
except: except:
return None return None
def get_prompts(self) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
with get_db() as db:
return [ return [
PromptModel(**model_to_dict(prompt)) PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
for prompt in Prompt.select()
# .limit(limit).offset(skip)
] ]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: try:
query = Prompt.update( with get_db() as db:
title=form_data.title,
content=form_data.content, prompt = db.query(Prompt).filter_by(command=command).first()
timestamp=int(time.time()), prompt.title = form_data.title
).where(Prompt.command == command) prompt.content = form_data.content
prompt.timestamp = int(time.time())
query.execute() db.commit()
return PromptModel.model_validate(prompt)
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
except: except:
return None return None
def delete_prompt_by_command(self, command: str) -> bool: def delete_prompt_by_command(self, command: str) -> bool:
try: try:
query = Prompt.delete().where((Prompt.command == command)) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed.
db.query(Prompt).filter_by(command=command).delete()
return True return True
except: except:
return False return False
Prompts = PromptsTable(DB) Prompts = PromptsTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional from typing import List, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json import json
import uuid import uuid
import time import time
import logging import logging
from apps.webui.internal.db import DB from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class Tag(Model): class Tag(Base):
id = CharField(unique=True) __tablename__ = "tag"
name = CharField()
user_id = CharField()
data = TextField(null=True)
class Meta: id = Column(String, primary_key=True)
database = DB name = Column(String)
user_id = Column(String)
data = Column(Text, nullable=True)
class ChatIdTag(Model): class ChatIdTag(Base):
id = CharField(unique=True) __tablename__ = "chatidtag"
tag_name = CharField()
chat_id = CharField()
user_id = CharField()
timestamp = BigIntegerField()
class Meta: id = Column(String, primary_key=True)
database = DB tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
class TagModel(BaseModel): class TagModel(BaseModel):
...@@ -47,6 +45,8 @@ class TagModel(BaseModel): ...@@ -47,6 +45,8 @@ class TagModel(BaseModel):
user_id: str user_id: str
data: Optional[str] = None data: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel): class ChatIdTagModel(BaseModel):
id: str id: str
...@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel): ...@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
user_id: str user_id: str
timestamp: int timestamp: int
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -75,17 +77,19 @@ class ChatTagsResponse(BaseModel): ...@@ -75,17 +77,19 @@ class ChatTagsResponse(BaseModel):
class TagTable: class TagTable:
def __init__(self, db):
self.db = db
db.create_tables([Tag, ChatIdTag])
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
result = Tag.create(**tag.model_dump()) result = Tag(**tag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return tag return TagModel.model_validate(result)
else: else:
return None return None
except Exception as e: except Exception as e:
...@@ -95,8 +99,9 @@ class TagTable: ...@@ -95,8 +99,9 @@ class TagTable:
self, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
tag = Tag.get(Tag.name == name, Tag.user_id == user_id) with get_db() as db:
return TagModel(**model_to_dict(tag)) tag = db.query(Tag).filter(name=name, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception as e: except Exception as e:
return None return None
...@@ -118,81 +123,108 @@ class TagTable: ...@@ -118,81 +123,108 @@ class TagTable:
} }
) )
try: try:
result = ChatIdTag.create(**chatIdTag.model_dump()) with get_db() as db:
result = ChatIdTag(**chatIdTag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return chatIdTag return ChatIdTagModel.model_validate(result)
else: else:
return None return None
except: except:
return None return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
with get_db() as db:
tag_names = [ tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name chat_id_tag.tag_name
for chat_id_tag in ChatIdTag.select() for chat_id_tag in (
.where(ChatIdTag.user_id == user_id) db.query(ChatIdTag)
.filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc()) .order_by(ChatIdTag.timestamp.desc())
.all()
)
] ]
return [ return [
TagModel(**model_to_dict(tag)) TagModel.model_validate(tag)
for tag in Tag.select() for tag in (
.where(Tag.user_id == user_id) db.query(Tag)
.where(Tag.name.in_(tag_names)) .filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
] ]
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> List[TagModel]: ) -> List[TagModel]:
with get_db() as db:
tag_names = [ tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name chat_id_tag.tag_name
for chat_id_tag in ChatIdTag.select() for chat_id_tag in (
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id)) db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
.order_by(ChatIdTag.timestamp.desc()) .order_by(ChatIdTag.timestamp.desc())
.all()
)
] ]
return [ return [
TagModel(**model_to_dict(tag)) TagModel.model_validate(tag)
for tag in Tag.select() for tag in (
.where(Tag.user_id == user_id) db.query(Tag)
.where(Tag.name.in_(tag_names)) .filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
] ]
def get_chat_ids_by_tag_name_and_user_id( def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> Optional[ChatIdTagModel]: ) -> List[ChatIdTagModel]:
with get_db() as db:
return [ return [
ChatIdTagModel(**model_to_dict(chat_id_tag)) ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in ChatIdTag.select() for chat_id_tag in (
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name)) db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc()) .order_by(ChatIdTag.timestamp.desc())
.all()
)
] ]
def count_chat_ids_by_tag_name_and_user_id( def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> int: ) -> int:
with get_db() as db:
return ( return (
ChatIdTag.select() db.query(ChatIdTag)
.where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)) .filter_by(tag_name=tag_name, user_id=user_id)
.count() .count()
) )
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try: try:
query = ChatIdTag.delete().where( with get_db() as db:
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id) res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.delete()
) )
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}") log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
query = Tag.delete().where( db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
(Tag.name == tag_name) & (Tag.user_id == user_id)
)
query.execute() # Remove the rows, return number of rows removed.
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
...@@ -202,21 +234,22 @@ class TagTable: ...@@ -202,21 +234,22 @@ class TagTable:
self, tag_name: str, chat_id: str, user_id: str self, tag_name: str, chat_id: str, user_id: str
) -> bool: ) -> bool:
try: try:
query = ChatIdTag.delete().where( with get_db() as db:
(ChatIdTag.tag_name == tag_name)
& (ChatIdTag.chat_id == chat_id) res = (
& (ChatIdTag.user_id == user_id) db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete()
) )
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}") log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
query = Tag.delete().where( db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
(Tag.name == tag_name) & (Tag.user_id == user_id)
)
query.execute() # Remove the rows, return number of rows removed.
return True return True
except Exception as e: except Exception as e:
...@@ -234,4 +267,4 @@ class TagTable: ...@@ -234,4 +267,4 @@ class TagTable:
return True return True
Tags = TagTable(DB) Tags = TagTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import * from typing import List, Optional
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time import time
import logging import logging
from apps.webui.internal.db import DB, JSONField from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
...@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class Tool(Model): class Tool(Base):
id = CharField(unique=True) __tablename__ = "tool"
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
valves = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta: id = Column(String, primary_key=True)
database = DB user_id = Column(String)
name = Column(Text)
content = Column(Text)
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ToolMeta(BaseModel): class ToolMeta(BaseModel):
...@@ -51,6 +50,8 @@ class ToolModel(BaseModel): ...@@ -51,6 +50,8 @@ class ToolModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -78,13 +79,13 @@ class ToolValves(BaseModel): ...@@ -78,13 +79,13 @@ class ToolValves(BaseModel):
class ToolsTable: class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool( def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict] self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
with get_db() as db:
tool = ToolModel( tool = ToolModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
...@@ -96,9 +97,12 @@ class ToolsTable: ...@@ -96,9 +97,12 @@ class ToolsTable:
) )
try: try:
result = Tool.create(**tool.model_dump()) result = Tool(**tool.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return tool return ToolModel.model_validate(result)
else: else:
return None return None
except Exception as e: except Exception as e:
...@@ -107,17 +111,22 @@ class ToolsTable: ...@@ -107,17 +111,22 @@ class ToolsTable:
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
tool = Tool.get(Tool.id == id) with get_db() as db:
return ToolModel(**model_to_dict(tool))
tool = db.get(Tool, id)
return ToolModel.model_validate(tool)
except: except:
return None return None
def get_tools(self) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] with get_db() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
tool = Tool.get(Tool.id == id) with get_db() as db:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
...@@ -125,14 +134,13 @@ class ToolsTable: ...@@ -125,14 +134,13 @@ class ToolsTable:
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try: try:
query = Tool.update( with get_db() as db:
**{"valves": valves},
updated_at=int(time.time()), db.query(Tool).filter_by(id=id).update(
).where(Tool.id == id) {"valves": valves, "updated_at": int(time.time())}
query.execute() )
db.commit()
tool = Tool.get(Tool.id == id) return self.get_tool_by_id(id)
return ToolValves(**model_to_dict(tool))
except: except:
return None return None
...@@ -179,25 +187,23 @@ class ToolsTable: ...@@ -179,25 +187,23 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
query = Tool.update( with get_db() as db:
**updated, tool = db.get(Tool, id)
updated_at=int(time.time()), tool.update(**updated)
).where(Tool.id == id) tool.updated_at = int(time.time())
query.execute() db.commit()
db.refresh(tool)
tool = Tool.get(Tool.id == id) return ToolModel.model_validate(tool)
return ToolModel(**model_to_dict(tool))
except: except:
return None return None
def delete_tool_by_id(self, id: str) -> bool: def delete_tool_by_id(self, id: str) -> bool:
try: try:
query = Tool.delete().where((Tool.id == id)) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed. db.query(Tool).filter_by(id=id).delete()
return True return True
except: except:
return False return False
Tools = ToolsTable(DB) Tools = ToolsTable()
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, parse_obj_as
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import Base, JSONField, Session, get_db
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
#################### ####################
...@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats ...@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
#################### ####################
class User(Model): class User(Base):
id = CharField(unique=True) __tablename__ = "user"
name = CharField()
email = CharField()
role = CharField()
profile_image_url = TextField()
last_active_at = BigIntegerField() id = Column(String, primary_key=True)
updated_at = BigIntegerField() name = Column(String)
created_at = BigIntegerField() email = Column(String)
role = Column(String)
profile_image_url = Column(Text)
api_key = CharField(null=True, unique=True) last_active_at = Column(BigInteger)
settings = JSONField(null=True) updated_at = Column(BigInteger)
info = JSONField(null=True) created_at = Column(BigInteger)
oauth_sub = TextField(null=True, unique=True) api_key = Column(String, nullable=True, unique=True)
settings = Column(JSONField, nullable=True)
info = Column(JSONField, nullable=True)
class Meta: oauth_sub = Column(Text, unique=True)
database = DB
class UserSettings(BaseModel): class UserSettings(BaseModel):
...@@ -57,6 +57,8 @@ class UserModel(BaseModel): ...@@ -57,6 +57,8 @@ class UserModel(BaseModel):
oauth_sub: Optional[str] = None oauth_sub: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel): ...@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
class UsersTable: class UsersTable:
def __init__(self, db):
self.db = db
self.db.create_tables([User])
def insert_new_user( def insert_new_user(
self, self,
...@@ -89,6 +88,7 @@ class UsersTable: ...@@ -89,6 +88,7 @@ class UsersTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db:
user = UserModel( user = UserModel(
**{ **{
"id": id, "id": id,
...@@ -102,7 +102,10 @@ class UsersTable: ...@@ -102,7 +102,10 @@ class UsersTable:
"oauth_sub": oauth_sub, "oauth_sub": oauth_sub,
} }
) )
result = User.create(**user.model_dump()) result = User(**user.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result: if result:
return user return user
else: else:
...@@ -110,56 +113,67 @@ class UsersTable: ...@@ -110,56 +113,67 @@ class UsersTable:
def get_user_by_id(self, id: str) -> Optional[UserModel]: def get_user_by_id(self, id: str) -> Optional[UserModel]:
try: try:
user = User.get(User.id == id) with get_db() as db:
return UserModel(**model_to_dict(user)) user = db.query(User).filter_by(id=id).first()
except: return UserModel.model_validate(user)
except Exception as e:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try: try:
user = User.get(User.api_key == api_key) with get_db() as db:
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except: except:
return None return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
try: try:
user = User.get(User.email == email) with get_db() as db:
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except: except:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try: try:
user = User.get(User.oauth_sub == sub) with get_db() as db:
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
except: except:
return None return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ with get_db() as db:
UserModel(**model_to_dict(user)) users = (
for user in User.select() db.query(User)
# .limit(limit).offset(skip) # .offset(skip).limit(limit)
] .all()
)
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
return User.select().count() with get_db() as db:
return db.query(User).count()
def get_first_user(self) -> UserModel: def get_first_user(self) -> UserModel:
try: try:
user = User.select().order_by(User.created_at).first() with get_db() as db:
return UserModel(**model_to_dict(user)) user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user)
except: except:
return None return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try: try:
query = User.update(role=role).where(User.id == id) with get_db() as db:
query.execute() db.query(User).filter_by(id=id).update({"role": role})
db.commit()
user = User.get(User.id == id) user = db.query(User).filter_by(id=id).first()
return UserModel(**model_to_dict(user)) return UserModel.model_validate(user)
except: except:
return None return None
...@@ -167,23 +181,28 @@ class UsersTable: ...@@ -167,23 +181,28 @@ class UsersTable:
self, id: str, profile_image_url: str self, id: str, profile_image_url: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
query = User.update(profile_image_url=profile_image_url).where( with get_db() as db:
User.id == id db.query(User).filter_by(id=id).update(
{"profile_image_url": profile_image_url}
) )
query.execute() db.commit()
user = User.get(User.id == id) user = db.query(User).filter_by(id=id).first()
return UserModel(**model_to_dict(user)) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try: try:
query = User.update(last_active_at=int(time.time())).where(User.id == id) with get_db() as db:
query.execute()
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())}
)
db.commit()
user = User.get(User.id == id) user = db.query(User).filter_by(id=id).first()
return UserModel(**model_to_dict(user)) return UserModel.model_validate(user)
except: except:
return None return None
...@@ -191,22 +210,24 @@ class UsersTable: ...@@ -191,22 +210,24 @@ class UsersTable:
self, id: str, oauth_sub: str self, id: str, oauth_sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id) with get_db() as db:
query.execute() db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
user = User.get(User.id == id) user = db.query(User).filter_by(id=id).first()
return UserModel(**model_to_dict(user)) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try: try:
query = User.update(**updated).where(User.id == id) with get_db() as db:
query.execute() db.query(User).filter_by(id=id).update(updated)
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user)) user = db.query(User).filter_by(id=id).first()
except: return UserModel.model_validate(user)
# return UserModel(**user.dict())
except Exception as e:
return None return None
def delete_user_by_id(self, id: str) -> bool: def delete_user_by_id(self, id: str) -> bool:
...@@ -215,9 +236,10 @@ class UsersTable: ...@@ -215,9 +236,10 @@ class UsersTable:
result = Chats.delete_chats_by_user_id(id) result = Chats.delete_chats_by_user_id(id)
if result: if result:
with get_db() as db:
# Delete User # Delete User
query = User.delete().where(User.id == id) db.query(User).filter_by(id=id).delete()
query.execute() # Remove the rows, return number of rows removed. db.commit()
return True return True
else: else:
...@@ -227,19 +249,20 @@ class UsersTable: ...@@ -227,19 +249,20 @@ class UsersTable:
def update_user_api_key_by_id(self, id: str, api_key: str) -> str: def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try: try:
query = User.update(api_key=api_key).where(User.id == id) with get_db() as db:
result = query.execute() result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]: def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try: try:
user = User.get(User.id == id) with get_db() as db:
user = db.query(User).filter_by(id=id).first()
return user.api_key return user.api_key
except: except Exception as e:
return None return None
Users = UsersTable(DB) Users = UsersTable()
...@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user ...@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse]) @router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id( async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50 user_id: str,
user=Depends(get_admin_user),
skip: int = 0,
limit: int = 50,
): ):
return Chats.get_chat_list_by_user_id( return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit user_id, include_archived=True, skip=skip, limit=limit
...@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)): ...@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=List[ChatResponse]) @router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)): async def get_user_archived_chats(user=Depends(get_verified_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id) for chat in Chats.get_archived_chats_by_user_id(user.id)
......
...@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_ ...@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
@router.post("/doc/update", response_model=Optional[DocumentResponse]) @router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name( async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user) name: str,
form_data: DocumentUpdateForm,
user=Depends(get_admin_user),
): ):
doc = Documents.update_doc_by_name(name, form_data) doc = Documents.update_doc_by_name(name, form_data)
if doc: if doc:
......
...@@ -50,10 +50,7 @@ router = APIRouter() ...@@ -50,10 +50,7 @@ router = APIRouter()
@router.post("/") @router.post("/")
def upload_file( def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
try: try:
unsanitized_filename = file.filename unsanitized_filename = file.filename
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment