Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
df09d083
Commit
df09d083
authored
Jun 18, 2024
by
Jonathan Rohde
Browse files
feat(sqlalchemy): Replace peewee with sqlalchemy
parent
8dac2a21
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
860 additions
and
804 deletions
+860
-804
.github/workflows/integration-test.yml
.github/workflows/integration-test.yml
+2
-2
backend/alembic.ini
backend/alembic.ini
+114
-0
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+5
-2
backend/apps/openai/main.py
backend/apps/openai/main.py
+3
-1
backend/apps/socket/main.py
backend/apps/socket/main.py
+3
-1
backend/apps/webui/internal/db.py
backend/apps/webui/internal/db.py
+40
-26
backend/apps/webui/internal/wrappers.py
backend/apps/webui/internal/wrappers.py
+0
-72
backend/apps/webui/main.py
backend/apps/webui/main.py
+3
-3
backend/apps/webui/models/auths.py
backend/apps/webui/models/auths.py
+40
-37
backend/apps/webui/models/chats.py
backend/apps/webui/models/chats.py
+139
-162
backend/apps/webui/models/documents.py
backend/apps/webui/models/documents.py
+48
-56
backend/apps/webui/models/files.py
backend/apps/webui/models/files.py
+30
-32
backend/apps/webui/models/functions.py
backend/apps/webui/models/functions.py
+36
-38
backend/apps/webui/models/memories.py
backend/apps/webui/models/memories.py
+43
-44
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+38
-41
backend/apps/webui/models/prompts.py
backend/apps/webui/models/prompts.py
+38
-48
backend/apps/webui/models/tags.py
backend/apps/webui/models/tags.py
+109
-89
backend/apps/webui/models/tools.py
backend/apps/webui/models/tools.py
+37
-41
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+92
-83
backend/apps/webui/routers/auths.py
backend/apps/webui/routers/auths.py
+40
-26
No files found.
.github/workflows/integration-test.yml
View file @
df09d083
...
...
@@ -171,7 +171,7 @@ jobs:
fi
# Check that service will reconnect to postgres when connection will be closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health
/db
)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check"
exit 1
...
...
@@ -183,7 +183,7 @@ jobs:
cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health
/db
)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1
...
...
backend/alembic.ini
0 → 100644
View file @
df09d083
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location
=
migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path
=
.
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator
=
os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url
=
REPLACE_WITH_DATABASE_URL
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys
=
root,sqlalchemy,alembic
[handlers]
keys
=
console
[formatters]
keys
=
generic
[logger_root]
level
=
WARN
handlers
=
console
qualname
=
[logger_sqlalchemy]
level
=
WARN
handlers
=
qualname
=
sqlalchemy.engine
[logger_alembic]
level
=
INFO
handlers
=
qualname
=
alembic
[handler_console]
class
=
StreamHandler
args
=
(sys.stderr,)
level
=
NOTSET
formatter
=
generic
[formatter_generic]
format
=
%(levelname)-5.5s [%(name)s] %(message)s
datefmt
=
%H:%M:%S
backend/apps/ollama/main.py
View file @
df09d083
...
...
@@ -31,6 +31,7 @@ from typing import Optional, List, Union
from
starlette.background
import
BackgroundTask
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.models
import
Models
from
apps.webui.models.users
import
Users
from
constants
import
ERROR_MESSAGES
...
...
@@ -711,6 +712,7 @@ async def generate_chat_completion(
form_data
:
GenerateChatCompletionForm
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
db
=
Depends
(
get_db
),
):
log
.
debug
(
...
...
@@ -724,7 +726,7 @@ async def generate_chat_completion(
}
model_id
=
form_data
.
model
model_info
=
Models
.
get_model_by_id
(
model_id
)
model_info
=
Models
.
get_model_by_id
(
db
,
model_id
)
if
model_info
:
if
model_info
.
base_model_id
:
...
...
@@ -883,6 +885,7 @@ async def generate_openai_chat_completion(
form_data
:
dict
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
db
=
Depends
(
get_db
),
):
form_data
=
OpenAIChatCompletionForm
(
**
form_data
)
...
...
@@ -891,7 +894,7 @@ async def generate_openai_chat_completion(
}
model_id
=
form_data
.
model
model_info
=
Models
.
get_model_by_id
(
model_id
)
model_info
=
Models
.
get_model_by_id
(
db
,
model_id
)
if
model_info
:
if
model_info
.
base_model_id
:
...
...
backend/apps/openai/main.py
View file @
df09d083
...
...
@@ -11,6 +11,7 @@ import logging
from
pydantic
import
BaseModel
from
starlette.background
import
BackgroundTask
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.models
import
Models
from
apps.webui.models.users
import
Users
from
constants
import
ERROR_MESSAGES
...
...
@@ -353,12 +354,13 @@ async def generate_chat_completion(
form_data
:
dict
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
db
=
Depends
(
get_db
),
):
idx
=
0
payload
=
{
**
form_data
}
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
model_info
=
Models
.
get_model_by_id
(
db
,
model_id
)
if
model_info
:
if
model_info
.
base_model_id
:
...
...
backend/apps/socket/main.py
View file @
df09d083
...
...
@@ -24,7 +24,9 @@ async def connect(sid, environ, auth):
data
=
decode_token
(
auth
[
"token"
])
if
data
is
not
None
and
"id"
in
data
:
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
from
apps.webui.internal.db
import
SessionLocal
user
=
Users
.
get_user_by_id
(
SessionLocal
(),
data
[
"id"
])
if
user
:
SESSION_POOL
[
sid
]
=
user
.
id
...
...
backend/apps/webui/internal/db.py
View file @
df09d083
import
os
import
logging
import
json
from
typing
import
Optional
,
Any
from
typing_extensions
import
Self
from
peewee
import
*
from
peewee_migrate
import
Router
from
sqlalchemy
import
create_engine
,
types
,
Dialect
from
sqlalchemy.ext.declarative
import
declarative_base
from
sqlalchemy.orm
import
sessionmaker
from
sqlalchemy.sql.type_api
import
_T
from
apps.webui.internal.wrappers
import
register_connection
from
config
import
SRC_LOG_LEVELS
,
DATA_DIR
,
DATABASE_URL
,
BACKEND_DIR
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
class
JSONField
(
TextField
):
class
JSONField
(
types
.
TypeDecorator
):
impl
=
types
.
Text
cache_ok
=
True
def
process_bind_param
(
self
,
value
:
Optional
[
_T
],
dialect
:
Dialect
)
->
Any
:
return
json
.
dumps
(
value
)
def
process_result_value
(
self
,
value
:
Optional
[
_T
],
dialect
:
Dialect
)
->
Any
:
if
value
is
not
None
:
return
json
.
loads
(
value
)
def
copy
(
self
,
**
kw
:
Any
)
->
Self
:
return
JSONField
(
self
.
impl
.
length
)
def
db_value
(
self
,
value
):
return
json
.
dumps
(
value
)
...
...
@@ -29,26 +45,24 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
else
:
pass
SQLALCHEMY_DATABASE_URL
=
DATABASE_URL
if
"sqlite"
in
SQLALCHEMY_DATABASE_URL
:
engine
=
create_engine
(
SQLALCHEMY_DATABASE_URL
,
connect_args
=
{
"check_same_thread"
:
False
}
)
else
:
engine
=
create_engine
(
SQLALCHEMY_DATABASE_URL
,
pool_pre_ping
=
True
)
SessionLocal
=
sessionmaker
(
autocommit
=
False
,
autoflush
=
False
,
bind
=
engine
)
Base
=
declarative_base
()
# The `register_connection` function encapsulates the logic for setting up
# the database connection based on the connection string, while `connect`
# is a Peewee-specific method to manage the connection state and avoid errors
# when a connection is already open.
try
:
DB
=
register_connection
(
DATABASE_URL
)
log
.
info
(
f
"Connected to a
{
DB
.
__class__
.
__name__
}
database."
)
except
Exception
as
e
:
log
.
error
(
f
"Failed to initialize the database connection:
{
e
}
"
)
raise
router
=
Router
(
DB
,
migrate_dir
=
BACKEND_DIR
/
"apps"
/
"webui"
/
"internal"
/
"migrations"
,
logger
=
log
,
)
router
.
run
()
try
:
DB
.
connect
(
reuse_if_open
=
True
)
except
OperationalError
as
e
:
log
.
info
(
f
"Failed to connect to database again due to:
{
e
}
"
)
pass
def
get_db
():
db
=
SessionLocal
()
try
:
yield
db
db
.
commit
()
except
Exception
as
e
:
db
.
rollback
()
raise
e
finally
:
db
.
close
()
backend/apps/webui/internal/wrappers.py
deleted
100644 → 0
View file @
8dac2a21
from
contextvars
import
ContextVar
from
peewee
import
*
from
peewee
import
PostgresqlDatabase
,
InterfaceError
as
PeeWeeInterfaceError
import
logging
from
playhouse.db_url
import
connect
,
parse
from
playhouse.shortcuts
import
ReconnectMixin
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
db_state_default
=
{
"closed"
:
None
,
"conn"
:
None
,
"ctx"
:
None
,
"transactions"
:
None
}
db_state
=
ContextVar
(
"db_state"
,
default
=
db_state_default
.
copy
())
class
PeeweeConnectionState
(
object
):
def
__init__
(
self
,
**
kwargs
):
super
().
__setattr__
(
"_state"
,
db_state
)
super
().
__init__
(
**
kwargs
)
def
__setattr__
(
self
,
name
,
value
):
self
.
_state
.
get
()[
name
]
=
value
def
__getattr__
(
self
,
name
):
value
=
self
.
_state
.
get
()[
name
]
return
value
class
CustomReconnectMixin
(
ReconnectMixin
):
reconnect_errors
=
(
# psycopg2
(
OperationalError
,
"termin"
),
(
InterfaceError
,
"closed"
),
# peewee
(
PeeWeeInterfaceError
,
"closed"
),
)
class
ReconnectingPostgresqlDatabase
(
CustomReconnectMixin
,
PostgresqlDatabase
):
pass
def
register_connection
(
db_url
):
db
=
connect
(
db_url
)
if
isinstance
(
db
,
PostgresqlDatabase
):
# Enable autoconnect for SQLite databases, managed by Peewee
db
.
autoconnect
=
True
db
.
reuse_if_open
=
True
log
.
info
(
"Connected to PostgreSQL database"
)
# Get the connection details
connection
=
parse
(
db_url
)
# Use our custom database class that supports reconnection
db
=
ReconnectingPostgresqlDatabase
(
connection
[
"database"
],
user
=
connection
[
"user"
],
password
=
connection
[
"password"
],
host
=
connection
[
"host"
],
port
=
connection
[
"port"
],
)
db
.
connect
(
reuse_if_open
=
True
)
elif
isinstance
(
db
,
SqliteDatabase
):
# Enable autoconnect for SQLite databases, managed by Peewee
db
.
autoconnect
=
True
db
.
reuse_if_open
=
True
log
.
info
(
"Connected to SQLite database"
)
else
:
raise
ValueError
(
"Unsupported database connection"
)
return
db
backend/apps/webui/main.py
View file @
df09d083
...
...
@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
from
fastapi.responses
import
StreamingResponse
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
from
sqlalchemy.orm
import
Session
from
apps.webui.routers
import
(
auths
,
users
,
...
...
@@ -114,8 +114,8 @@ async def get_status():
}
async
def
get_pipe_models
():
pipes
=
Functions
.
get_functions_by_type
(
"pipe"
,
active_only
=
True
)
async
def
get_pipe_models
(
db
:
Session
):
pipes
=
Functions
.
get_functions_by_type
(
db
,
"pipe"
,
active_only
=
True
)
pipe_models
=
[]
for
pipe
in
pipes
:
...
...
backend/apps/webui/models/auths.py
View file @
df09d083
from
pydantic
import
BaseModel
from
typing
import
List
,
Union
,
Optional
import
time
from
typing
import
Optional
import
uuid
import
logging
from
peewee
import
*
from
sqlalchemy
import
String
,
Column
,
Boolean
from
sqlalchemy.orm
import
Session
from
apps.webui.models.users
import
UserModel
,
Users
from
utils.utils
import
verify_password
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
from
config
import
SRC_LOG_LEVELS
...
...
@@ -20,14 +20,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Auth
(
Model
):
id
=
CharField
(
unique
=
True
)
email
=
CharField
()
password
=
TextField
()
active
=
BooleanField
()
class
Auth
(
Base
):
__tablename__
=
"auth"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
email
=
Column
(
String
)
password
=
Column
(
String
)
active
=
Column
(
Boolean
)
class
AuthModel
(
BaseModel
):
...
...
@@ -94,12 +93,10 @@ class AddUserForm(SignupForm):
class
AuthsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Auth
])
def
insert_new_auth
(
self
,
db
:
Session
,
email
:
str
,
password
:
str
,
name
:
str
,
...
...
@@ -114,24 +111,30 @@ class AuthsTable:
auth
=
AuthModel
(
**
{
"id"
:
id
,
"email"
:
email
,
"password"
:
password
,
"active"
:
True
}
)
result
=
Auth
.
create
(
**
auth
.
model_dump
())
result
=
Auth
(
**
auth
.
model_dump
())
db
.
add
(
result
)
user
=
Users
.
insert_new_user
(
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
db
,
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
and
user
:
return
user
else
:
return
None
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
def
authenticate_user
(
self
,
db
:
Session
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user:
{
email
}
"
)
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
auth
=
db
.
query
(
Auth
).
filter_by
(
email
=
email
,
active
=
True
)
.
first
()
if
auth
:
if
verify_password
(
password
,
auth
.
password
):
user
=
Users
.
get_user_by_id
(
auth
.
id
)
user
=
Users
.
get_user_by_id
(
db
,
auth
.
id
)
return
user
else
:
return
None
...
...
@@ -140,55 +143,55 @@ class AuthsTable:
except
:
return
None
def
authenticate_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
def
authenticate_user_by_api_key
(
self
,
db
:
Session
,
api_key
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user_by_api_key:
{
api_key
}
"
)
# if no api_key, return None
if
not
api_key
:
return
None
try
:
user
=
Users
.
get_user_by_api_key
(
api_key
)
user
=
Users
.
get_user_by_api_key
(
db
,
api_key
)
return
user
if
user
else
None
except
:
return
False
def
authenticate_user_by_trusted_header
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
def
authenticate_user_by_trusted_header
(
self
,
db
:
Session
,
email
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user_by_trusted_header:
{
email
}
"
)
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
auth
=
db
.
query
(
Auth
).
filter
(
email
=
email
,
active
=
True
)
.
first
()
if
auth
:
user
=
Users
.
get_user_by_id
(
auth
.
id
)
return
user
except
:
return
None
def
update_user_password_by_id
(
self
,
id
:
str
,
new_password
:
str
)
->
bool
:
def
update_user_password_by_id
(
self
,
db
:
Session
,
id
:
str
,
new_password
:
str
)
->
bool
:
try
:
query
=
Auth
.
update
(
password
=
new_password
).
where
(
Auth
.
id
==
id
)
result
=
query
.
execute
()
result
=
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"password"
:
new_password
})
return
True
if
result
==
1
else
False
except
:
return
False
def
update_email_by_id
(
self
,
id
:
str
,
email
:
str
)
->
bool
:
def
update_email_by_id
(
self
,
db
:
Session
,
id
:
str
,
email
:
str
)
->
bool
:
try
:
query
=
Auth
.
update
(
email
=
email
).
where
(
Auth
.
id
==
id
)
result
=
query
.
execute
()
result
=
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"email"
:
email
})
return
True
if
result
==
1
else
False
except
:
return
False
def
delete_auth_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_auth_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
try
:
# Delete User
result
=
Users
.
delete_user_by_id
(
id
)
result
=
Users
.
delete_user_by_id
(
db
,
id
)
if
result
:
# Delete Auth
query
=
Auth
.
delete
().
where
(
Auth
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
delete
()
return
True
else
:
...
...
@@ -197,4 +200,4 @@ class AuthsTable:
return
False
Auths
=
AuthsTable
(
DB
)
Auths
=
AuthsTable
()
backend/apps/webui/models/chats.py
View file @
df09d083
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
import
json
import
uuid
import
time
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Boolean
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
####################
# Chat DB Schema
####################
class
Chat
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
title
=
TextField
()
chat
=
TextField
()
# Save Chat JSON as Text
class
Chat
(
Base
):
__tablename__
=
"chat"
created_at
=
BigIntegerField
()
updated_at
=
BigIntegerField
()
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
title
=
Column
(
String
)
chat
=
Column
(
String
)
# Save Chat JSON as Text
share_id
=
CharField
(
null
=
True
,
unique
=
True
)
archived
=
BooleanField
(
default
=
False
)
created_at
=
Column
(
BigInteger
)
updated_at
=
Column
(
BigInteger
)
class
Meta
:
database
=
DB
share_id
=
Column
(
String
,
unique
=
True
,
nullable
=
True
)
archived
=
Column
(
Boolean
,
default
=
False
)
class
ChatModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
id
:
str
user_id
:
str
title
:
str
...
...
@@ -75,11 +78,10 @@ class ChatTitleIdResponse(BaseModel):
class
ChatTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
db
.
create_tables
([
Chat
])
def
insert_new_chat
(
self
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
def
insert_new_chat
(
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
id
=
str
(
uuid
.
uuid4
())
chat
=
ChatModel
(
**
{
...
...
@@ -94,29 +96,36 @@ class ChatTable:
}
)
result
=
Chat
.
create
(
**
chat
.
model_dump
())
return
chat
if
result
else
None
result
=
Chat
(
**
chat
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
return
ChatModel
.
model_validate
(
result
)
if
result
else
None
def
update_chat_by_id
(
self
,
id
:
str
,
chat
:
dict
)
->
Optional
[
ChatModel
]:
def
update_chat_by_id
(
self
,
db
:
Session
,
id
:
str
,
chat
:
dict
)
->
Optional
[
ChatModel
]:
try
:
query
=
Chat
.
update
(
chat
=
json
.
dumps
(
chat
),
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New C
hat
"
,
updated_at
=
int
(
time
.
time
())
,
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
)
)
db
.
query
(
Chat
).
filter_by
(
id
=
id
)
.
update
(
{
"chat"
:
json
.
dumps
(
c
hat
)
,
"title"
:
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
,
"updated_at"
:
int
(
time
.
time
()),
}
)
return
self
.
get_chat_by_id
(
db
,
id
)
except
:
return
None
def
insert_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
def
insert_shared_chat_by_chat_id
(
self
,
db
:
Session
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
# Get the existing chat to share
chat
=
Chat
.
get
(
Chat
.
id
==
chat_id
)
chat
=
db
.
get
(
Chat
,
chat_id
)
# Check if the chat is already shared
if
chat
.
share_id
:
return
self
.
get_chat_by_id_and_user_id
(
chat
.
share_id
,
"shared"
)
return
self
.
get_chat_by_id_and_user_id
(
db
,
chat
.
share_id
,
"shared"
)
# Create a new chat with the same data, but with a new ID
shared_chat
=
ChatModel
(
**
{
...
...
@@ -128,228 +137,196 @@ class ChatTable:
"updated_at"
:
int
(
time
.
time
()),
}
)
shared_result
=
Chat
.
create
(
**
shared_chat
.
model_dump
())
shared_result
=
Chat
(
**
shared_chat
.
model_dump
())
db
.
add
(
shared_result
)
db
.
commit
()
db
.
refresh
(
shared_result
)
# Update the original chat with the share_id
result
=
(
Chat
.
update
(
share_id
=
shared_chat
.
id
).
where
(
Chat
.
id
==
chat_id
).
execute
(
)
db
.
query
(
Chat
).
filter_by
(
id
=
chat_id
)
.
update
(
{
"
share_id
"
:
shared_chat
.
id
}
)
)
return
shared_chat
if
(
shared_result
and
result
)
else
None
def
update_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
def
update_shared_chat_by_chat_id
(
self
,
db
:
Session
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
print
(
"update_shared_chat_by_id"
)
chat
=
Chat
.
get
(
Chat
.
id
==
chat_id
)
chat
=
db
.
get
(
Chat
,
chat_id
)
print
(
chat
)
query
=
Chat
.
update
(
title
=
chat
.
title
,
chat
=
chat
.
chat
,
).
where
(
Chat
.
id
==
chat
.
share_id
)
db
.
query
(
Chat
).
filter_by
(
id
=
chat
.
share_id
).
update
(
{
"title"
:
chat
.
title
,
"chat"
:
chat
.
chat
}
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
chat
.
share_id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
return
self
.
get_chat_by_id
(
db
,
chat
.
share_id
)
except
:
return
None
def
delete_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
bool
:
def
delete_shared_chat_by_chat_id
(
self
,
db
:
Session
,
chat_id
:
str
)
->
bool
:
try
:
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
==
f
"shared-
{
chat_id
}
"
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Chat
).
filter_by
(
user_id
=
f
"shared-
{
chat_id
}
"
).
delete
()
return
True
except
:
return
False
def
update_chat_share_id_by_id
(
self
,
id
:
str
,
share_id
:
Optional
[
str
]
self
,
db
:
Session
,
id
:
str
,
share_id
:
Optional
[
str
]
)
->
Optional
[
ChatModel
]:
try
:
query
=
Chat
.
update
(
share_id
=
share_id
,
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
update
({
"share_id"
:
share_id
})
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
return
self
.
get_chat_by_id
(
db
,
id
)
except
:
return
None
def
toggle_chat_archive_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
def
toggle_chat_archive_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
self
.
get_chat_by_id
(
id
)
query
=
Chat
.
update
(
archived
=
(
not
chat
.
archived
),
).
where
(
Chat
.
id
==
id
)
chat
=
self
.
get_chat_by_id
(
db
,
id
)
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
update
({
"archived"
:
not
chat
.
archived
})
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
return
self
.
get_chat_by_id
(
db
,
id
)
except
:
return
None
def
archive_all_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
def
archive_all_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
bool
:
try
:
chats
=
self
.
get_chats_by_user_id
(
user_id
)
for
chat
in
chats
:
query
=
Chat
.
update
(
archived
=
True
,
).
where
(
Chat
.
id
==
chat
.
id
)
query
.
execute
()
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
update
({
"archived"
:
True
})
return
True
except
:
return
False
def
get_archived_chat_list_by_user_id
(
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
self
,
db
:
Session
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
True
)
.
where
(
Chat
.
user_id
==
user_id
)
all_chats
=
(
db
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
# .limit(limit).offset(skip)
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_list_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
,
include_archived
:
bool
=
False
,
skip
:
int
=
0
,
limit
:
int
=
50
,
)
->
List
[
ChatModel
]:
if
include_archived
:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
else
:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
False
)
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
query
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
)
if
not
include_archived
:
query
=
query
.
filter_by
(
archived
=
False
)
all_chats
=
(
query
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit).offset(skip)
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_list_by_chat_ids
(
self
,
chat_ids
:
List
[
str
],
skip
:
int
=
0
,
limit
:
int
=
50
self
,
db
:
Session
,
chat_ids
:
List
[
str
],
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
False
)
.
where
(
Chat
.
id
.
in_
(
chat_ids
))
all_chats
=
(
db
.
query
(
Chat
)
.
filter
(
Chat
.
id
.
in_
(
chat_ids
))
.
filter_by
(
archived
=
False
)
.
order_by
(
Chat
.
updated_at
.
desc
())
]
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
def
get_chat_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_
to_dict
(
chat
)
)
chat
=
db
.
get
(
Chat
,
id
)
return
ChatModel
.
model_
validate
(
chat
)
except
:
return
None
def
get_chat_by_share_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
def
get_chat_by_share_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
Chat
.
get
(
Chat
.
share_id
==
id
)
chat
=
db
.
query
(
Chat
).
filter_by
(
share_id
=
id
).
first
(
)
if
chat
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
return
self
.
get_chat_by_id
(
db
,
id
)
else
:
return
None
except
:
except
Exception
as
e
:
return
None
def
get_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
def
get_chat_by_id_and_user_id
(
self
,
db
:
Session
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
,
Chat
.
user_id
==
user_id
)
return
ChatModel
(
**
model_
to_dict
(
chat
)
)
chat
=
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
)
.
first
()
return
ChatModel
.
model_
validate
(
chat
)
except
:
return
None
def
get_chats
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
().
order_by
(
Chat
.
updated_at
.
desc
())
def
get_chats
(
self
,
db
:
Session
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
all_chats
=
(
db
.
query
(
Chat
)
# .limit(limit).offset(skip)
]
def
get_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit).offset(skip)
]
def
get_archived_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
True
)
.
where
(
Chat
.
user_id
==
user_id
)
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
List
[
ChatModel
]:
all_chats
=
(
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
order_by
(
Chat
.
updated_at
.
desc
())
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_archived_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
List
[
ChatModel
]:
all_chats
=
(
db
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
]
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
delete_chat_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_chat_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
try
:
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
delete
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
return
True
and
self
.
delete_shared_chat_by_chat_id
(
db
,
id
)
except
:
return
False
def
delete_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
def
delete_chat_by_id_and_user_id
(
self
,
db
:
Session
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
)
&
(
Chat
.
user_id
==
user_id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
return
True
and
self
.
delete_shared_chat_by_chat_id
(
db
,
id
)
except
:
return
False
def
delete_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
def
delete_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
bool
:
try
:
self
.
delete_shared_chats_by_user_id
(
user_id
)
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
==
user_id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
self
.
delete_shared_chats_by_user_id
(
db
,
user_id
)
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
def
delete_shared_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
def
delete_shared_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
bool
:
try
:
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
Chat
.
select
().
where
(
Chat
.
user_id
==
user_id
)
]
chats_by_user
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
all
()
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
chats_by_user
]
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
<<
shared_chat_ids
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Chat
).
filter
(
Chat
.
user_id
.
in_
(
shared_chat_ids
)).
delete
()
return
True
except
:
return
False
Chats
=
ChatTable
(
DB
)
Chats
=
ChatTable
()
backend/apps/webui/models/documents.py
View file @
df09d083
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
import
logging
from
utils.utils
import
decode_token
from
utils.misc
import
get_gravatar_url
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
import
json
...
...
@@ -22,20 +20,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Document
(
Model
):
collection_name
=
CharField
(
unique
=
True
)
name
=
CharField
(
unique
=
True
)
title
=
TextField
()
filename
=
TextField
()
content
=
TextField
(
null
=
True
)
user_id
=
CharField
()
timestamp
=
BigIntegerField
()
class
Document
(
Base
):
__tablename__
=
"document"
class
Meta
:
database
=
DB
collection_name
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
,
unique
=
True
)
title
=
Column
(
String
)
filename
=
Column
(
String
)
content
=
Column
(
String
,
nullable
=
True
)
user_id
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
DocumentModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
collection_name
:
str
name
:
str
title
:
str
...
...
@@ -72,12 +71,9 @@ class DocumentForm(DocumentUpdateForm):
class
DocumentsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Document
])
def
insert_new_doc
(
self
,
user_id
:
str
,
form_data
:
DocumentForm
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
DocumentForm
)
->
Optional
[
DocumentModel
]:
document
=
DocumentModel
(
**
{
...
...
@@ -88,73 +84,69 @@ class DocumentsTable:
)
try
:
result
=
Document
.
create
(
**
document
.
model_dump
())
result
=
Document
(
**
document
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
d
ocument
return
D
ocument
Model
.
model_validate
(
result
)
else
:
return
None
except
:
return
None
def
get_doc_by_name
(
self
,
name
:
str
)
->
Optional
[
DocumentModel
]:
def
get_doc_by_name
(
self
,
db
:
Session
,
name
:
str
)
->
Optional
[
DocumentModel
]:
try
:
document
=
Document
.
get
(
Document
.
name
==
name
)
return
DocumentModel
(
**
model_
to_dict
(
document
)
)
document
=
db
.
query
(
Document
).
filter_by
(
name
=
name
).
first
(
)
return
DocumentModel
.
model_
validate
(
document
)
if
document
else
None
except
:
return
None
def
get_docs
(
self
)
->
List
[
DocumentModel
]:
return
[
DocumentModel
(
**
model_to_dict
(
doc
))
for
doc
in
Document
.
select
()
# .limit(limit).offset(skip)
]
def
get_docs
(
self
,
db
:
Session
)
->
List
[
DocumentModel
]:
return
[
DocumentModel
.
model_validate
(
doc
)
for
doc
in
db
.
query
(
Document
).
all
()]
def
update_doc_by_name
(
self
,
name
:
str
,
form_data
:
DocumentUpdateForm
self
,
db
:
Session
,
name
:
str
,
form_data
:
DocumentUpdateForm
)
->
Optional
[
DocumentModel
]:
try
:
query
=
Document
.
update
(
title
=
form_data
.
title
,
name
=
form_data
.
name
,
timestamp
=
int
(
time
.
time
()),
).
where
(
Document
.
name
==
name
)
query
.
execute
()
doc
=
Document
.
get
(
Document
.
name
==
form_data
.
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
))
db
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
{
"title"
:
form_data
.
title
,
"name"
:
form_data
.
name
,
"timestamp"
:
int
(
time
.
time
()),
}
)
return
self
.
get_doc_by_name
(
db
,
form_data
.
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
return
None
def
update_doc_content_by_name
(
self
,
name
:
str
,
updated
:
dict
self
,
db
:
Session
,
name
:
str
,
updated
:
dict
)
->
Optional
[
DocumentModel
]:
try
:
doc
=
self
.
get_doc_by_name
(
name
)
doc
=
self
.
get_doc_by_name
(
db
,
name
)
doc_content
=
json
.
loads
(
doc
.
content
if
doc
.
content
else
"{}"
)
doc_content
=
{
**
doc_content
,
**
updated
}
query
=
Document
.
update
(
content
=
json
.
dumps
(
doc_content
),
timestamp
=
int
(
time
.
time
()),
).
where
(
Document
.
name
==
name
)
query
.
execute
()
db
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
{
"content"
:
json
.
dumps
(
doc_content
),
"timestamp"
:
int
(
time
.
time
()),
}
)
doc
=
Document
.
get
(
Document
.
name
==
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
))
return
self
.
get_doc_by_name
(
db
,
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
return
None
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
def
delete_doc_by_name
(
self
,
db
:
Session
,
name
:
str
)
->
bool
:
try
:
query
=
Document
.
delete
().
where
((
Document
.
name
==
name
))
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Document
).
filter_by
(
name
=
name
).
delete
()
return
True
except
:
return
False
Documents
=
DocumentsTable
(
DB
)
Documents
=
DocumentsTable
()
backend/apps/webui/models/files.py
View file @
df09d083
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
Column
,
String
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
JSONField
,
Base
import
json
...
...
@@ -18,15 +20,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
File
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
filename
=
TextField
()
meta
=
JSONField
()
created_at
=
BigIntegerField
()
class
File
(
Base
):
__tablename__
=
"file"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
filename
=
Column
(
String
)
meta
=
Column
(
JSONField
)
created_at
=
Column
(
BigInteger
)
class
FileModel
(
BaseModel
):
...
...
@@ -36,6 +37,7 @@ class FileModel(BaseModel):
meta
:
dict
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -57,11 +59,8 @@ class FileForm(BaseModel):
class
FilesTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
File
])
def
insert_new_file
(
self
,
user_id
:
str
,
form_data
:
FileForm
)
->
Optional
[
FileModel
]:
def
insert_new_file
(
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
FileForm
)
->
Optional
[
FileModel
]:
file
=
FileModel
(
**
{
**
form_data
.
model_dump
(),
...
...
@@ -71,42 +70,41 @@ class FilesTable:
)
try
:
result
=
File
.
create
(
**
file
.
model_dump
())
result
=
File
(
**
file
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
f
ile
return
F
ile
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
print
(
f
"Error creating tool:
{
e
}
"
)
return
None
def
get_file_by_id
(
self
,
id
:
str
)
->
Optional
[
FileModel
]:
def
get_file_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
FileModel
]:
try
:
file
=
File
.
get
(
File
.
id
==
id
)
return
FileModel
(
**
model_
to_dict
(
file
)
)
file
=
db
.
get
(
File
,
id
)
return
FileModel
.
model_
validate
(
file
)
except
:
return
None
def
get_files
(
self
)
->
List
[
FileModel
]:
return
[
FileModel
(
**
model_
to_dict
(
file
)
)
for
file
in
File
.
select
()]
def
get_files
(
self
,
db
:
Session
)
->
List
[
FileModel
]:
return
[
FileModel
.
model_
validate
(
file
)
for
file
in
db
.
query
(
File
).
all
()]
def
delete_file_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_file_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
try
:
query
=
File
.
delete
().
where
((
File
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
File
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
def
delete_all_files
(
self
)
->
bool
:
def
delete_all_files
(
self
,
db
:
Session
)
->
bool
:
try
:
query
=
File
.
delete
()
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
File
).
delete
()
return
True
except
:
return
False
Files
=
FilesTable
(
DB
)
Files
=
FilesTable
()
backend/apps/webui/models/functions.py
View file @
df09d083
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
Column
,
String
,
Text
,
BigInteger
,
Boolean
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
JSONField
,
Base
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -21,20 +23,19 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Function
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
name
=
TextField
()
type
=
TextField
()
content
=
TextField
()
meta
=
JSONField
()
valves
=
JSONField
()
is_active
=
BooleanField
(
default
=
False
)
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Function
(
Base
):
__tablename__
=
"function"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
name
=
Column
(
Text
)
type
=
Column
(
Text
)
content
=
Column
(
Text
)
meta
=
Column
(
JSONField
)
valves
=
Column
(
JSONField
)
is_active
=
Column
(
Boolean
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
FunctionMeta
(
BaseModel
):
...
...
@@ -53,6 +54,8 @@ class FunctionModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -82,12 +85,9 @@ class FunctionValves(BaseModel):
class
FunctionsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Function
])
def
insert_new_function
(
self
,
user_id
:
str
,
type
:
str
,
form_data
:
FunctionForm
self
,
db
:
Session
,
user_id
:
str
,
type
:
str
,
form_data
:
FunctionForm
)
->
Optional
[
FunctionModel
]:
function
=
FunctionModel
(
**
{
...
...
@@ -100,19 +100,22 @@ class FunctionsTable:
)
try
:
result
=
Function
.
create
(
**
function
.
model_dump
())
result
=
Function
(
**
function
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
f
unction
return
F
unction
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
print
(
f
"Error creating tool:
{
e
}
"
)
return
None
def
get_function_by_id
(
self
,
id
:
str
)
->
Optional
[
FunctionModel
]:
def
get_function_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
FunctionModel
]:
try
:
function
=
Function
.
get
(
Function
.
id
==
id
)
return
FunctionModel
(
**
model_
to_dict
(
function
)
)
function
=
db
.
get
(
Function
,
id
)
return
FunctionModel
.
model_
validate
(
function
)
except
:
return
None
...
...
@@ -211,14 +214,11 @@ class FunctionsTable:
def
update_function_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
FunctionModel
]:
try
:
query
=
Function
.
update
(
db
.
query
(
Function
).
filter_by
(
id
=
id
)
.
update
(
{
**
updated
,
updated_at
=
int
(
time
.
time
()),
).
where
(
Function
.
id
==
id
)
query
.
execute
()
function
=
Function
.
get
(
Function
.
id
==
id
)
return
FunctionModel
(
**
model_to_dict
(
function
))
"updated_at"
:
int
(
time
.
time
()),
})
return
self
.
get_function_by_id
(
db
,
id
)
except
:
return
None
...
...
@@ -235,14 +235,12 @@ class FunctionsTable:
except
:
return
None
def
delete_function_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_function_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
try
:
query
=
Function
.
delete
().
where
((
Function
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Function
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
Functions
=
FunctionsTable
(
DB
)
Functions
=
FunctionsTable
()
backend/apps/webui/models/memories.py
View file @
df09d083
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
Column
,
String
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
from
apps.webui.models.chats
import
Chats
import
time
...
...
@@ -14,15 +15,14 @@ import uuid
####################
class
Memory
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
content
=
TextField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Memory
(
Base
):
__tablename__
=
"memory"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
content
=
Column
(
String
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
MemoryModel
(
BaseModel
):
...
...
@@ -32,6 +32,8 @@ class MemoryModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -39,12 +41,10 @@ class MemoryModel(BaseModel):
class
MemoriesTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Memory
])
def
insert_new_memory
(
self
,
db
:
Session
,
user_id
:
str
,
content
:
str
,
)
->
Optional
[
MemoryModel
]:
...
...
@@ -59,74 +59,73 @@ class MemoriesTable:
"updated_at"
:
int
(
time
.
time
()),
}
)
result
=
Memory
.
create
(
**
memory
.
model_dump
())
result
=
Memory
(
**
memory
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
m
emory
return
M
emory
Model
.
model_validate
(
result
)
else
:
return
None
def
update_memory_by_id
(
self
,
db
:
Session
,
id
:
str
,
content
:
str
,
)
->
Optional
[
MemoryModel
]:
try
:
memory
=
Memory
.
get
(
Memory
.
id
==
id
)
memory
.
content
=
content
memory
.
updated_at
=
int
(
time
.
time
())
memory
.
save
()
return
MemoryModel
(
**
model_to_dict
(
memory
))
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
update
(
{
"content"
:
content
,
"updated_at"
:
int
(
time
.
time
())}
)
return
self
.
get_memory_by_id
(
db
,
id
)
except
:
return
None
def
get_memories
(
self
)
->
List
[
MemoryModel
]:
def
get_memories
(
self
,
db
:
Session
)
->
List
[
MemoryModel
]:
try
:
memories
=
Memory
.
select
()
return
[
MemoryModel
(
**
model_
to_dict
(
memory
)
)
for
memory
in
memories
]
memories
=
db
.
query
(
Memory
).
all
()
return
[
MemoryModel
.
model_
validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memories_by_user_id
(
self
,
user_id
:
str
)
->
List
[
MemoryModel
]:
def
get_memories_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
List
[
MemoryModel
]:
try
:
memories
=
Memory
.
select
().
where
(
Memory
.
user_id
==
user_id
)
return
[
MemoryModel
(
**
model_
to_dict
(
memory
)
)
for
memory
in
memories
]
memories
=
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
)
.
all
()
return
[
MemoryModel
.
model_
validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memory_by_id
(
self
,
i
d
)
->
Optional
[
MemoryModel
]:
def
get_memory_by_id
(
self
,
d
b
:
Session
,
id
:
str
)
->
Optional
[
MemoryModel
]:
try
:
memory
=
Memory
.
get
(
Memory
.
id
==
id
)
return
MemoryModel
(
**
model_
to_dict
(
memory
)
)
memory
=
db
.
get
(
Memory
,
id
)
return
MemoryModel
.
model_
validate
(
memory
)
except
:
return
None
def
delete_memory_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_memory_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
def
delete_memories_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
def
delete_memories_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
bool
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
user_id
==
user_id
)
query
.
execute
()
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
def
delete_memory_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
def
delete_memory_by_id_and_user_id
(
self
,
db
:
Session
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
,
Memory
.
user_id
==
user_id
)
query
.
execute
()
db
.
query
(
Memory
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
Memories
=
MemoriesTable
(
DB
)
Memories
=
MemoriesTable
()
backend/apps/webui/models/models.py
View file @
df09d083
...
...
@@ -2,13 +2,11 @@ import json
import
logging
from
typing
import
Optional
import
peewee
as
pw
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
...
...
@@ -46,41 +44,42 @@ class ModelMeta(BaseModel):
pass
class
Model
(
pw
.
Model
):
id
=
pw
.
TextField
(
unique
=
True
)
class
Model
(
Base
):
__tablename__
=
"model"
id
=
Column
(
String
,
primary_key
=
True
)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id
=
pw
.
TextField
(
)
user_id
=
Column
(
String
)
base_model_id
=
pw
.
TextField
(
null
=
True
)
base_model_id
=
Column
(
String
,
nullable
=
True
)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name
=
pw
.
TextField
(
)
name
=
Column
(
String
)
"""
The human-readable display name of the model.
"""
params
=
JSONField
(
)
params
=
Column
(
JSONField
)
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta
=
JSONField
(
)
meta
=
Column
(
JSONField
)
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Meta
:
database
=
DB
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
ModelModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
id
:
str
user_id
:
str
base_model_id
:
Optional
[
str
]
=
None
...
...
@@ -115,15 +114,9 @@ class ModelForm(BaseModel):
class
ModelsTable
:
def
__init__
(
self
,
db
:
pw
.
SqliteDatabase
|
pw
.
PostgresqlDatabase
,
):
self
.
db
=
db
self
.
db
.
create_tables
([
Model
])
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
self
,
db
:
Session
,
form_data
:
ModelForm
,
user_id
:
str
)
->
Optional
[
ModelModel
]:
model
=
ModelModel
(
**
{
...
...
@@ -134,46 +127,50 @@ class ModelsTable:
}
)
try
:
result
=
Model
.
create
(
**
model
.
model_dump
())
result
=
Model
(
**
model
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
model
return
ModelModel
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
print
(
e
)
return
None
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
return
[
ModelModel
(
**
model_
to_dict
(
model
)
)
for
model
in
Model
.
select
()]
def
get_all_models
(
self
,
db
:
Session
)
->
List
[
ModelModel
]:
return
[
ModelModel
.
model_
validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
def
get_model_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_
to_dict
(
model
)
)
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
.
model_
validate
(
model
)
except
:
return
None
def
update_model_by_id
(
self
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
def
update_model_by_id
(
self
,
db
:
Session
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
try
:
# update only the fields that are present in the model
query
=
Model
.
update
(
**
model
.
model_dump
()).
where
(
Model
.
id
==
id
)
query
.
execute
()
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_
to_dict
(
model
)
)
model
=
db
.
query
(
Model
).
get
(
id
)
model
.
update
(
**
model
.
model_dump
()
)
db
.
commit
()
db
.
refresh
(
model
)
return
ModelModel
.
model_
validate
(
model
)
except
Exception
as
e
:
print
(
e
)
return
None
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_model_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
try
:
query
=
Model
.
delete
().
where
(
Model
.
id
==
id
)
query
.
execute
()
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
Models
=
ModelsTable
(
DB
)
Models
=
ModelsTable
()
backend/apps/webui/models/prompts.py
View file @
df09d083
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
from
utils.utils
import
decode_token
from
utils.misc
import
get_gravatar_url
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
import
json
...
...
@@ -16,15 +14,14 @@ import json
####################
class
Prompt
(
Model
):
command
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
title
=
TextField
()
content
=
TextField
()
timestamp
=
BigIntegerField
()
class
Prompt
(
Base
):
__tablename__
=
"prompt"
class
Meta
:
database
=
DB
command
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
title
=
Column
(
String
)
content
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
PromptModel
(
BaseModel
):
...
...
@@ -34,6 +31,8 @@ class PromptModel(BaseModel):
content
:
str
timestamp
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -48,12 +47,8 @@ class PromptForm(BaseModel):
class
PromptsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Prompt
])
def
insert_new_prompt
(
self
,
user_id
:
str
,
form_data
:
PromptForm
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
prompt
=
PromptModel
(
**
{
...
...
@@ -66,53 +61,48 @@ class PromptsTable:
)
try
:
result
=
Prompt
.
create
(
**
prompt
.
model_dump
())
result
=
Prompt
(
**
prompt
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
p
rompt
return
P
rompt
Model
.
model_validate
(
result
)
else
:
return
None
except
:
except
Exception
as
e
:
return
None
def
get_prompt_by_command
(
self
,
command
:
str
)
->
Optional
[
PromptModel
]:
def
get_prompt_by_command
(
self
,
db
:
Session
,
command
:
str
)
->
Optional
[
PromptModel
]:
try
:
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
return
PromptModel
(
**
model_
to_dict
(
prompt
)
)
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
)
.
first
()
return
PromptModel
.
model_
validate
(
prompt
)
except
:
return
None
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
return
[
PromptModel
(
**
model_to_dict
(
prompt
))
for
prompt
in
Prompt
.
select
()
# .limit(limit).offset(skip)
]
def
get_prompts
(
self
,
db
:
Session
)
->
List
[
PromptModel
]:
return
[
PromptModel
.
model_validate
(
prompt
)
for
prompt
in
db
.
query
(
Prompt
).
all
()]
def
update_prompt_by_command
(
self
,
command
:
str
,
form_data
:
PromptForm
self
,
db
:
Session
,
command
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
try
:
query
=
Prompt
.
update
(
title
=
form_data
.
title
,
content
=
form_data
.
content
,
timestamp
=
int
(
time
.
time
()),
).
where
(
Prompt
.
command
==
command
)
query
.
execute
()
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
return
PromptModel
(
**
model_to_dict
(
prompt
))
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
update
(
{
"title"
:
form_data
.
title
,
"content"
:
form_data
.
content
,
"timestamp"
:
int
(
time
.
time
()),
}
)
return
self
.
get_prompt_by_command
(
db
,
command
)
except
:
return
None
def
delete_prompt_by_command
(
self
,
command
:
str
)
->
bool
:
def
delete_prompt_by_command
(
self
,
db
:
Session
,
command
:
str
)
->
bool
:
try
:
query
=
Prompt
.
delete
().
where
((
Prompt
.
command
==
command
))
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
delete
()
return
True
except
:
return
False
Prompts
=
PromptsTable
(
DB
)
Prompts
=
PromptsTable
()
backend/apps/webui/models/tags.py
View file @
df09d083
from
pydantic
import
BaseModel
from
typing
import
List
,
Union
,
Optional
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
json
import
uuid
import
time
import
logging
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
from
config
import
SRC_LOG_LEVELS
...
...
@@ -20,25 +21,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Tag
(
Model
):
id
=
CharField
(
unique
=
True
)
name
=
CharField
()
user_id
=
CharField
()
data
=
TextField
(
null
=
True
)
class
Tag
(
Base
):
__tablename__
=
"tag"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
)
user_id
=
Column
(
String
)
data
=
Column
(
String
,
nullable
=
True
)
class
ChatIdTag
(
Model
):
id
=
CharField
(
unique
=
True
)
tag_name
=
CharField
()
chat_id
=
CharField
()
user_id
=
CharField
()
timestamp
=
BigIntegerField
()
class
ChatIdTag
(
Base
):
__tablename__
=
"chatidtag"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
tag_name
=
Column
(
String
)
chat_id
=
Column
(
String
)
user_id
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
TagModel
(
BaseModel
):
...
...
@@ -47,6 +46,8 @@ class TagModel(BaseModel):
user_id
:
str
data
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
class
ChatIdTagModel
(
BaseModel
):
id
:
str
...
...
@@ -55,6 +56,8 @@ class ChatIdTagModel(BaseModel):
user_id
:
str
timestamp
:
int
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -75,37 +78,39 @@ class ChatTagsResponse(BaseModel):
class
TagTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
db
.
create_tables
([
Tag
,
ChatIdTag
])
def
insert_new_tag
(
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
def
insert_new_tag
(
self
,
db
:
Session
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
id
=
str
(
uuid
.
uuid4
())
tag
=
TagModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"name"
:
name
})
try
:
result
=
Tag
.
create
(
**
tag
.
model_dump
())
result
=
Tag
(
**
tag
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
t
ag
return
T
ag
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
return
None
def
get_tag_by_name_and_user_id
(
self
,
name
:
str
,
user_id
:
str
self
,
db
:
Session
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
try
:
tag
=
Tag
.
get
(
Tag
.
name
==
name
,
Tag
.
user_id
==
user_id
)
return
TagModel
(
**
model_
to_dict
(
tag
)
)
tag
=
db
.
query
(
Tag
).
filter
(
name
=
name
,
user_id
=
user_id
)
.
first
()
return
TagModel
.
model_
validate
(
tag
)
except
Exception
as
e
:
return
None
def
add_tag_to_chat
(
self
,
user_id
:
str
,
form_data
:
ChatIdTagForm
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
ChatIdTagForm
)
->
Optional
[
ChatIdTagModel
]:
tag
=
self
.
get_tag_by_name_and_user_id
(
form_data
.
tag_name
,
user_id
)
tag
=
self
.
get_tag_by_name_and_user_id
(
db
,
form_data
.
tag_name
,
user_id
)
if
tag
==
None
:
tag
=
self
.
insert_new_tag
(
form_data
.
tag_name
,
user_id
)
tag
=
self
.
insert_new_tag
(
db
,
form_data
.
tag_name
,
user_id
)
id
=
str
(
uuid
.
uuid4
())
chatIdTag
=
ChatIdTagModel
(
...
...
@@ -118,120 +123,135 @@ class TagTable:
}
)
try
:
result
=
ChatIdTag
.
create
(
**
chatIdTag
.
model_dump
())
result
=
ChatIdTag
(
**
chatIdTag
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
c
hatIdTag
return
C
hatIdTag
Model
.
model_validate
(
result
)
else
:
return
None
except
:
return
None
def
get_tags_by_user_id
(
self
,
user_id
:
str
)
->
List
[
TagModel
]:
def
get_tags_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
List
[
TagModel
]:
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
(
ChatIdTag
.
user_id
==
user_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
return
[
TagModel
(
**
model_to_dict
(
tag
))
for
tag
in
Tag
.
select
()
.
where
(
Tag
.
user_id
==
user_id
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
TagModel
.
model_validate
(
tag
)
for
tag
in
(
db
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
def
get_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
self
,
db
:
Session
,
chat_id
:
str
,
user_id
:
str
)
->
List
[
TagModel
]:
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
chat_id
==
chat_id
))
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
chat_id
=
chat_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
return
[
TagModel
(
**
model_to_dict
(
tag
))
for
tag
in
Tag
.
select
()
.
where
(
Tag
.
user_id
==
user_id
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
TagModel
.
model_validate
(
tag
)
for
tag
in
(
db
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
def
get_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
Optional
[
ChatIdTagModel
]:
self
,
db
:
Session
,
tag_name
:
str
,
user_id
:
str
)
->
List
[
ChatIdTagModel
]:
return
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
))
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
tag_name
==
tag_name
))
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
ChatIdTagModel
.
model_validate
(
chat_id_tag
)
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
tag_name
=
tag_name
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
def
count_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
self
,
db
:
Session
,
tag_name
:
str
,
user_id
:
str
)
->
int
:
return
(
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
))
.
count
()
)
return
db
.
query
(
ChatIdTag
).
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
).
count
()
def
delete_tag_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
bool
:
def
delete_tag_by_tag_name_and_user_id
(
self
,
db
:
Session
,
tag_name
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
res
=
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
delete
()
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
db
,
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
return
True
except
Exception
as
e
:
log
.
error
(
f
"delete_tag:
{
e
}
"
)
return
False
def
delete_tag_by_tag_name_and_chat_id_and_user_id
(
self
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
self
,
db
:
Session
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
chat_id
==
chat
_id
)
&
(
ChatIdTag
.
user_id
==
user_id
)
res
=
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
chat_id
=
chat_id
,
user_id
=
user
_id
)
.
delete
(
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
db
,
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
return
True
except
Exception
as
e
:
log
.
error
(
f
"delete_tag:
{
e
}
"
)
return
False
def
delete_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
tags
=
self
.
get_tags_by_chat_id_and_user_id
(
chat_id
,
user_id
)
def
delete_tags_by_chat_id_and_user_id
(
self
,
db
:
Session
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
tags
=
self
.
get_tags_by_chat_id_and_user_id
(
db
,
chat_id
,
user_id
)
for
tag
in
tags
:
self
.
delete_tag_by_tag_name_and_chat_id_and_user_id
(
tag
.
tag_name
,
chat_id
,
user_id
db
,
tag
.
tag_name
,
chat_id
,
user_id
)
return
True
Tags
=
TagTable
(
DB
)
Tags
=
TagTable
()
backend/apps/webui/models/tools.py
View file @
df09d083
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
JSONField
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -21,19 +22,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Tool
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
name
=
TextField
()
content
=
TextField
()
specs
=
JSONField
()
meta
=
JSONField
()
valves
=
JSONField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Tool
(
Base
):
__tablename__
=
"tool"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
name
=
Column
(
String
)
content
=
Column
(
String
)
specs
=
Column
(
JSONField
)
meta
=
Column
(
JSONField
)
valves
=
Column
(
JSONField
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
ToolMeta
(
BaseModel
):
...
...
@@ -51,6 +51,8 @@ class ToolModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -78,12 +80,9 @@ class ToolValves(BaseModel):
class
ToolsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Tool
])
def
insert_new_tool
(
self
,
user_id
:
str
,
form_data
:
ToolForm
,
specs
:
List
[
dict
]
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
ToolForm
,
specs
:
List
[
dict
]
)
->
Optional
[
ToolModel
]:
tool
=
ToolModel
(
**
{
...
...
@@ -96,24 +95,27 @@ class ToolsTable:
)
try
:
result
=
Tool
.
create
(
**
tool
.
model_dump
())
result
=
Tool
(
**
tool
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
t
ool
return
T
ool
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
print
(
f
"Error creating tool:
{
e
}
"
)
return
None
def
get_tool_by_id
(
self
,
id
:
str
)
->
Optional
[
ToolModel
]:
def
get_tool_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ToolModel
]:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolModel
(
**
model_
to_dict
(
tool
)
)
tool
=
db
.
get
(
Tool
,
id
)
return
ToolModel
.
model_
validate
(
tool
)
except
:
return
None
def
get_tools
(
self
)
->
List
[
ToolModel
]:
return
[
ToolModel
(
**
model_
to_dict
(
tool
)
)
for
tool
in
Tool
.
select
()]
def
get_tools
(
self
,
db
:
Session
)
->
List
[
ToolModel
]:
return
[
ToolModel
.
model_
validate
(
tool
)
for
tool
in
db
.
query
(
Tool
).
all
()]
def
get_tool_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
...
...
@@ -180,25 +182,19 @@ class ToolsTable:
def
update_tool_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
ToolModel
]:
try
:
query
=
Tool
.
update
(
**
updated
,
updated_at
=
int
(
time
.
time
()),
).
where
(
Tool
.
id
==
id
)
query
.
execute
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolModel
(
**
model_to_dict
(
tool
))
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
**
updated
,
"updated_at"
:
int
(
time
.
time
())}
)
return
self
.
get_tool_by_id
(
db
,
id
)
except
:
return
None
def
delete_tool_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_tool_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
try
:
query
=
Tool
.
delete
().
where
((
Tool
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
Tools
=
ToolsTable
(
DB
)
Tools
=
ToolsTable
()
backend/apps/webui/models/users.py
View file @
df09d083
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
,
parse_obj_as
from
typing
import
List
,
Union
,
Optional
import
time
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
sqlalchemy.orm
import
Session
from
utils.misc
import
get_gravatar_url
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
from
apps.webui.models.chats
import
Chats
####################
...
...
@@ -13,25 +15,24 @@ from apps.webui.models.chats import Chats
####################
class
User
(
Model
):
id
=
CharField
(
unique
=
True
)
name
=
CharField
()
email
=
CharField
()
role
=
CharField
()
profile_image_url
=
TextField
()
class
User
(
Base
):
__tablename__
=
"user"
last_active_at
=
BigIntegerField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
id
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
)
email
=
Column
(
String
)
role
=
Column
(
String
)
profile_image_url
=
Column
(
String
)
api_key
=
CharField
(
null
=
True
,
unique
=
True
)
settings
=
JSONField
(
null
=
True
)
info
=
JSONField
(
null
=
True
)
last_active_at
=
Column
(
BigInteger
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
oauth_sub
=
TextField
(
null
=
True
,
unique
=
True
)
api_key
=
Column
(
String
,
nullable
=
True
,
unique
=
True
)
settings
=
Column
(
JSONField
,
nullable
=
True
)
info
=
Column
(
JSONField
,
nullable
=
True
)
class
Meta
:
database
=
DB
oauth_sub
=
Column
(
Text
,
unique
=
True
)
class
UserSettings
(
BaseModel
):
...
...
@@ -41,6 +42,8 @@ class UserSettings(BaseModel):
class
UserModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
id
:
str
name
:
str
email
:
str
...
...
@@ -76,12 +79,10 @@ class UserUpdateForm(BaseModel):
class
UsersTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
User
])
def
insert_new_user
(
self
,
db
:
Session
,
id
:
str
,
name
:
str
,
email
:
str
,
...
...
@@ -102,30 +103,33 @@ class UsersTable:
"oauth_sub"
:
oauth_sub
,
}
)
result
=
User
.
create
(
**
user
.
model_dump
())
result
=
User
(
**
user
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
user
else
:
return
None
def
get_user_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
except
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
Exception
as
e
:
return
None
def
get_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_api_key
(
self
,
db
:
Session
,
api_key
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
api_key
==
api_key
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
api_key
=
api_key
)
.
first
()
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_email
(
self
,
db
:
Session
,
email
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
email
==
email
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
email
=
email
)
.
first
()
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
...
...
@@ -136,88 +140,94 @@ class UsersTable:
except
:
return
None
def
get_users
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
return
[
UserModel
(
**
model_to_dict
(
user
))
for
user
in
User
.
select
()
# .limit(limit).offset(skip)
]
def
get_users
(
self
,
db
:
Session
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
users
=
(
db
.
query
(
User
)
# .offset(skip).limit(limit)
.
all
()
)
return
[
UserModel
.
model_validate
(
user
)
for
user
in
users
]
def
get_num_users
(
self
)
->
Optional
[
int
]:
return
User
.
select
(
).
count
()
def
get_num_users
(
self
,
db
:
Session
)
->
Optional
[
int
]:
return
db
.
query
(
User
).
count
()
def
get_first_user
(
self
)
->
UserModel
:
def
get_first_user
(
self
,
db
:
Session
)
->
UserModel
:
try
:
user
=
User
.
select
(
).
order_by
(
User
.
created_at
).
first
()
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
order_by
(
User
.
created_at
).
first
()
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_role_by_id
(
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
def
update_user_role_by_id
(
self
,
db
:
Session
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
role
=
role
).
where
(
User
.
id
==
id
)
query
.
execute
()
db
.
query
(
User
).
filter_by
(
id
=
id
)
.
update
(
{
"
role
"
:
role
}
)
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_profile_image_url_by_id
(
self
,
id
:
str
,
profile_image_url
:
str
self
,
db
:
Session
,
id
:
str
,
profile_image_url
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
profile_image_url
=
profile_image_url
).
wher
e
(
User
.
id
==
id
db
.
query
(
User
).
filter_by
(
id
=
id
).
updat
e
(
{
"profile_image_url"
:
profile_image_url
}
)
query
.
execute
()
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_last_active_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
def
update_user_last_active_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
last_active_at
=
int
(
time
.
time
())).
where
(
User
.
id
==
id
)
query
.
execute
()
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"last_active_at"
:
int
(
time
.
time
())})
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_oauth_sub_by_id
(
self
,
id
:
str
,
oauth_sub
:
str
self
,
db
:
Session
,
id
:
str
,
oauth_sub
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
oauth_sub
=
oauth_sub
).
where
(
User
.
id
==
id
)
query
.
execute
()
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"oauth_sub"
:
oauth_sub
})
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
def
update_user_by_id
(
self
,
db
:
Session
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
**
updated
)
.
where
(
User
.
id
==
id
)
query
.
execute
()
db
.
query
(
User
).
filter_by
(
id
=
id
)
.
update
(
updated
)
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
# return UserModel(**user.dict())
except
Exception
as
e
:
return
None
def
delete_user_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_user_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
try
:
# Delete User Chats
result
=
Chats
.
delete_chats_by_user_id
(
id
)
result
=
Chats
.
delete_chats_by_user_id
(
db
,
id
)
if
result
:
# Delete User
query
=
User
.
delete
().
where
(
User
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
User
).
filter_by
(
id
=
id
).
delete
(
)
db
.
commit
()
return
True
else
:
...
...
@@ -225,21 +235,20 @@ class UsersTable:
except
:
return
False
def
update_user_api_key_by_id
(
self
,
id
:
str
,
api_key
:
str
)
->
str
:
def
update_user_api_key_by_id
(
self
,
db
:
Session
,
id
:
str
,
api_key
:
str
)
->
str
:
try
:
query
=
User
.
update
(
api_key
=
api_key
).
where
(
User
.
id
==
id
)
result
=
query
.
execute
()
result
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"api_key"
:
api_key
})
db
.
commit
()
return
True
if
result
==
1
else
False
except
:
return
False
def
get_user_api_key_by_id
(
self
,
id
:
str
)
->
Optional
[
str
]:
def
get_user_api_key_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
str
]:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
user
.
api_key
except
:
except
Exception
as
e
:
return
None
Users
=
UsersTable
(
DB
)
Users
=
UsersTable
()
backend/apps/webui/routers/auths.py
View file @
df09d083
...
...
@@ -10,6 +10,7 @@ import re
import
uuid
import
csv
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.auths
import
(
SigninForm
,
SignupForm
,
...
...
@@ -78,10 +79,13 @@ async def get_session_user(
@
router
.
post
(
"/update/profile"
,
response_model
=
UserResponse
)
async
def
update_profile
(
form_data
:
UpdateProfileForm
,
session_user
=
Depends
(
get_current_user
)
form_data
:
UpdateProfileForm
,
session_user
=
Depends
(
get_current_user
),
db
=
Depends
(
get_db
),
):
if
session_user
:
user
=
Users
.
update_user_by_id
(
db
,
session_user
.
id
,
{
"profile_image_url"
:
form_data
.
profile_image_url
,
"name"
:
form_data
.
name
},
)
...
...
@@ -100,16 +104,18 @@ async def update_profile(
@
router
.
post
(
"/update/password"
,
response_model
=
bool
)
async
def
update_password
(
form_data
:
UpdatePasswordForm
,
session_user
=
Depends
(
get_current_user
)
form_data
:
UpdatePasswordForm
,
session_user
=
Depends
(
get_current_user
),
db
=
Depends
(
get_db
),
):
if
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
ACTION_PROHIBITED
)
if
session_user
:
user
=
Auths
.
authenticate_user
(
session_user
.
email
,
form_data
.
password
)
user
=
Auths
.
authenticate_user
(
db
,
session_user
.
email
,
form_data
.
password
)
if
user
:
hashed
=
get_password_hash
(
form_data
.
new_password
)
return
Auths
.
update_user_password_by_id
(
user
.
id
,
hashed
)
return
Auths
.
update_user_password_by_id
(
db
,
user
.
id
,
hashed
)
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_PASSWORD
)
else
:
...
...
@@ -122,7 +128,7 @@ async def update_password(
@
router
.
post
(
"/signin"
,
response_model
=
SigninResponse
)
async
def
signin
(
request
:
Request
,
response
:
Response
,
form_data
:
SigninForm
):
async
def
signin
(
request
:
Request
,
response
:
Response
,
form_data
:
SigninForm
,
db
=
Depends
(
get_db
)
):
if
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
:
if
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
not
in
request
.
headers
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_TRUSTED_HEADER
)
...
...
@@ -133,32 +139,34 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
trusted_name
=
request
.
headers
.
get
(
WEBUI_AUTH_TRUSTED_NAME_HEADER
,
trusted_email
)
if
not
Users
.
get_user_by_email
(
trusted_email
.
lower
()):
if
not
Users
.
get_user_by_email
(
db
,
trusted_email
.
lower
()):
await
signup
(
request
,
SignupForm
(
email
=
trusted_email
,
password
=
str
(
uuid
.
uuid4
()),
name
=
trusted_name
),
db
,
)
user
=
Auths
.
authenticate_user_by_trusted_header
(
trusted_email
)
user
=
Auths
.
authenticate_user_by_trusted_header
(
db
,
trusted_email
)
elif
WEBUI_AUTH
==
False
:
admin_email
=
"admin@localhost"
admin_password
=
"admin"
if
Users
.
get_user_by_email
(
admin_email
.
lower
()):
user
=
Auths
.
authenticate_user
(
admin_email
.
lower
(),
admin_password
)
if
Users
.
get_user_by_email
(
db
,
admin_email
.
lower
()):
user
=
Auths
.
authenticate_user
(
db
,
admin_email
.
lower
(),
admin_password
)
else
:
if
Users
.
get_num_users
()
!=
0
:
if
Users
.
get_num_users
(
db
)
!=
0
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EXISTING_USERS
)
await
signup
(
request
,
SignupForm
(
email
=
admin_email
,
password
=
admin_password
,
name
=
"User"
),
db
,
)
user
=
Auths
.
authenticate_user
(
admin_email
.
lower
(),
admin_password
)
user
=
Auths
.
authenticate_user
(
db
,
admin_email
.
lower
(),
admin_password
)
else
:
user
=
Auths
.
authenticate_user
(
form_data
.
email
.
lower
(),
form_data
.
password
)
user
=
Auths
.
authenticate_user
(
db
,
form_data
.
email
.
lower
(),
form_data
.
password
)
if
user
:
token
=
create_token
(
...
...
@@ -192,7 +200,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@
router
.
post
(
"/signup"
,
response_model
=
SigninResponse
)
async
def
signup
(
request
:
Request
,
response
:
Response
,
form_data
:
SignupForm
):
async
def
signup
(
request
:
Request
,
response
:
Response
,
form_data
:
SignupForm
,
db
=
Depends
(
get_db
)
):
if
not
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
and
WEBUI_AUTH
:
raise
HTTPException
(
status
.
HTTP_403_FORBIDDEN
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
...
...
@@ -203,17 +211,18 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
INVALID_EMAIL_FORMAT
)
if
Users
.
get_user_by_email
(
form_data
.
email
.
lower
()):
if
Users
.
get_user_by_email
(
db
,
form_data
.
email
.
lower
()):
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EMAIL_TAKEN
)
try
:
role
=
(
"admin"
if
Users
.
get_num_users
()
==
0
if
Users
.
get_num_users
(
db
)
==
0
else
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
)
hashed
=
get_password_hash
(
form_data
.
password
)
user
=
Auths
.
insert_new_auth
(
db
,
form_data
.
email
.
lower
(),
hashed
,
form_data
.
name
,
...
...
@@ -267,14 +276,16 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@
router
.
post
(
"/add"
,
response_model
=
SigninResponse
)
async
def
add_user
(
form_data
:
AddUserForm
,
user
=
Depends
(
get_admin_user
)):
async
def
add_user
(
form_data
:
AddUserForm
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
)
):
if
not
validate_email_format
(
form_data
.
email
.
lower
()):
raise
HTTPException
(
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
INVALID_EMAIL_FORMAT
)
if
Users
.
get_user_by_email
(
form_data
.
email
.
lower
()):
if
Users
.
get_user_by_email
(
db
,
form_data
.
email
.
lower
()):
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EMAIL_TAKEN
)
try
:
...
...
@@ -282,6 +293,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
print
(
form_data
)
hashed
=
get_password_hash
(
form_data
.
password
)
user
=
Auths
.
insert_new_auth
(
db
,
form_data
.
email
.
lower
(),
hashed
,
form_data
.
name
,
...
...
@@ -312,7 +324,9 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@
router
.
get
(
"/admin/details"
)
async
def
get_admin_details
(
request
:
Request
,
user
=
Depends
(
get_current_user
)):
async
def
get_admin_details
(
request
:
Request
,
user
=
Depends
(
get_current_user
),
db
=
Depends
(
get_db
)
):
if
request
.
app
.
state
.
config
.
SHOW_ADMIN_DETAILS
:
admin_email
=
request
.
app
.
state
.
config
.
ADMIN_EMAIL
admin_name
=
None
...
...
@@ -320,11 +334,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
print
(
admin_email
,
admin_name
)
if
admin_email
:
admin
=
Users
.
get_user_by_email
(
admin_email
)
admin
=
Users
.
get_user_by_email
(
db
,
admin_email
)
if
admin
:
admin_name
=
admin
.
name
else
:
admin
=
Users
.
get_first_user
()
admin
=
Users
.
get_first_user
(
db
)
if
admin
:
admin_email
=
admin
.
email
admin_name
=
admin
.
name
...
...
@@ -397,9 +411,9 @@ async def update_admin_config(
# create api key
@
router
.
post
(
"/api_key"
,
response_model
=
ApiKey
)
async
def
create_api_key_
(
user
=
Depends
(
get_current_user
)):
async
def
create_api_key_
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
api_key
=
create_api_key
()
success
=
Users
.
update_user_api_key_by_id
(
user
.
id
,
api_key
)
success
=
Users
.
update_user_api_key_by_id
(
db
,
user
.
id
,
api_key
)
if
success
:
return
{
"api_key"
:
api_key
,
...
...
@@ -410,15 +424,15 @@ async def create_api_key_(user=Depends(get_current_user)):
# delete api key
@
router
.
delete
(
"/api_key"
,
response_model
=
bool
)
async
def
delete_api_key
(
user
=
Depends
(
get_current_user
)):
success
=
Users
.
update_user_api_key_by_id
(
user
.
id
,
None
)
async
def
delete_api_key
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
success
=
Users
.
update_user_api_key_by_id
(
db
,
user
.
id
,
None
)
return
success
# get api key
@
router
.
get
(
"/api_key"
,
response_model
=
ApiKey
)
async
def
get_api_key
(
user
=
Depends
(
get_current_user
)):
api_key
=
Users
.
get_user_api_key_by_id
(
user
.
id
)
async
def
get_api_key
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
api_key
=
Users
.
get_user_api_key_by_id
(
db
,
user
.
id
)
if
api_key
:
return
{
"api_key"
:
api_key
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment