Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
f9e3c47d
Commit
f9e3c47d
authored
Jul 09, 2024
by
Michael Poluektov
Browse files
rebase
parents
49b4211c
24ef5af2
Changes
149
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1370 additions
and
951 deletions
+1370
-951
.github/workflows/integration-test.yml
.github/workflows/integration-test.yml
+32
-6
backend/alembic.ini
backend/alembic.ini
+114
-0
backend/apps/rag/main.py
backend/apps/rag/main.py
+2
-1
backend/apps/webui/internal/db.py
backend/apps/webui/internal/db.py
+80
-24
backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py
.../apps/webui/internal/migrations/017_add_user_oauth_sub.py
+0
-4
backend/apps/webui/internal/migrations/README.md
backend/apps/webui/internal/migrations/README.md
+0
-21
backend/apps/webui/main.py
backend/apps/webui/main.py
+82
-2
backend/apps/webui/models/auths.py
backend/apps/webui/models/auths.py
+63
-56
backend/apps/webui/models/chats.py
backend/apps/webui/models/chats.py
+206
-187
backend/apps/webui/models/documents.py
backend/apps/webui/models/documents.py
+69
-62
backend/apps/webui/models/files.py
backend/apps/webui/models/files.py
+64
-50
backend/apps/webui/models/functions.py
backend/apps/webui/models/functions.py
+125
-106
backend/apps/webui/models/memories.py
backend/apps/webui/models/memories.py
+92
-76
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+53
-42
backend/apps/webui/models/prompts.py
backend/apps/webui/models/prompts.py
+47
-46
backend/apps/webui/models/tags.py
backend/apps/webui/models/tags.py
+141
-106
backend/apps/webui/models/tools.py
backend/apps/webui/models/tools.py
+74
-64
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+118
-94
backend/apps/webui/routers/chats.py
backend/apps/webui/routers/chats.py
+5
-3
backend/apps/webui/routers/documents.py
backend/apps/webui/routers/documents.py
+3
-1
No files found.
.github/workflows/integration-test.yml
View file @
f9e3c47d
...
...
@@ -35,6 +35,10 @@ jobs:
done
echo "Service is up!"
-
name
:
Delete Docker build cache
run
:
|
docker builder prune --all --force
-
name
:
Preload Ollama model
run
:
|
docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K
...
...
@@ -43,7 +47,7 @@ jobs:
uses
:
cypress-io/github-action@v6
with
:
browser
:
chrome
wait-on
:
"
http://localhost:3000
"
wait-on
:
'
http://localhost:3000
'
config
:
baseUrl=http://localhost:3000
-
uses
:
actions/upload-artifact@v4
...
...
@@ -67,6 +71,28 @@ jobs:
path
:
compose-logs.txt
if-no-files-found
:
ignore
# pytest:
# name: Run Backend Tests
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python
# uses: actions/setup-python@v4
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# pip install -r backend/requirements.txt
# - name: pytest run
# run: |
# ls -al
# cd backend
# PYTHONPATH=. pytest . -o log_cli=true -o log_cli_level=INFO
migration_test
:
name
:
Run Migration Tests
runs-on
:
ubuntu-latest
...
...
@@ -126,11 +152,11 @@ jobs:
cd backend
uvicorn main:app --port "8080" --forwarded-allow-ips '*' &
UVICORN_PID=$!
# Wait up to
2
0 seconds for the server to start
for i in {1..
2
0}; do
# Wait up to
4
0 seconds for the server to start
for i in {1..
4
0}; do
curl -s http://localhost:8080/api/config > /dev/null && break
sleep 1
if [ $i -eq
2
0 ]; then
if [ $i -eq
4
0 ]; then
echo "Server failed to start"
kill -9 $UVICORN_PID
exit 1
...
...
@@ -171,7 +197,7 @@ jobs:
fi
# Check that service will reconnect to postgres when connection will be closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health
/db
)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check"
exit 1
...
...
@@ -183,7 +209,7 @@ jobs:
cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health
/db
)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1
...
...
backend/alembic.ini
0 → 100644
View file @
f9e3c47d
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location
=
migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path
=
.
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator
=
os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# sqlalchemy.url = REPLACE_WITH_DATABASE_URL
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys
=
root,sqlalchemy,alembic
[handlers]
keys
=
console
[formatters]
keys
=
generic
[logger_root]
level
=
WARN
handlers
=
console
qualname
=
[logger_sqlalchemy]
level
=
WARN
handlers
=
qualname
=
sqlalchemy.engine
[logger_alembic]
level
=
INFO
handlers
=
qualname
=
alembic
[handler_console]
class
=
StreamHandler
args
=
(sys.stderr,)
level
=
NOTSET
formatter
=
generic
[formatter_generic]
format
=
%(levelname)-5.5s [%(name)s] %(message)s
datefmt
=
%H:%M:%S
backend/apps/rag/main.py
View file @
f9e3c47d
...
...
@@ -1004,10 +1004,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
return
True
except
Exception
as
e
:
log
.
exception
(
e
)
if
e
.
__class__
.
__name__
==
"UniqueConstraintError"
:
return
True
log
.
exception
(
e
)
return
False
...
...
backend/apps/webui/internal/db.py
View file @
f9e3c47d
import
os
import
logging
import
json
from
contextlib
import
contextmanager
from
peewee
import
*
from
peewee_migrate
import
Router
from
apps.webui.internal.wrappers
import
register_connection
from
typing
import
Optional
,
Any
from
typing_extensions
import
Self
from
sqlalchemy
import
create_engine
,
types
,
Dialect
from
sqlalchemy.ext.declarative
import
declarative_base
from
sqlalchemy.orm
import
sessionmaker
,
scoped_session
from
sqlalchemy.sql.type_api
import
_T
from
config
import
SRC_LOG_LEVELS
,
DATA_DIR
,
DATABASE_URL
,
BACKEND_DIR
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
class
JSONField
(
TextField
):
class
JSONField
(
types
.
TypeDecorator
):
impl
=
types
.
Text
cache_ok
=
True
def
process_bind_param
(
self
,
value
:
Optional
[
_T
],
dialect
:
Dialect
)
->
Any
:
return
json
.
dumps
(
value
)
def
process_result_value
(
self
,
value
:
Optional
[
_T
],
dialect
:
Dialect
)
->
Any
:
if
value
is
not
None
:
return
json
.
loads
(
value
)
def
copy
(
self
,
**
kw
:
Any
)
->
Self
:
return
JSONField
(
self
.
impl
.
length
)
def
db_value
(
self
,
value
):
return
json
.
dumps
(
value
)
...
...
@@ -30,25 +51,60 @@ else:
pass
# The `register_connection` function encapsulates the logic for setting up
# the database connection based on the connection string, while `connect`
# is a Peewee-specific method to manage the connection state and avoid errors
# when a connection is already open.
try
:
DB
=
register_connection
(
DATABASE_URL
)
log
.
info
(
f
"Connected to a
{
DB
.
__class__
.
__name__
}
database."
)
except
Exception
as
e
:
# Workaround to handle the peewee migration
# This is required to ensure the peewee migration is handled before the alembic migration
def
handle_peewee_migration
(
DATABASE_URL
):
try
:
# Replace the postgresql:// with postgres:// and %40 with @ in the DATABASE_URL
db
=
register_connection
(
DATABASE_URL
.
replace
(
"postgresql://"
,
"postgres://"
).
replace
(
"%40"
,
"@"
)
)
migrate_dir
=
BACKEND_DIR
/
"apps"
/
"webui"
/
"internal"
/
"migrations"
router
=
Router
(
db
,
logger
=
log
,
migrate_dir
=
migrate_dir
)
router
.
run
()
db
.
close
()
# check if db connection has been closed
except
Exception
as
e
:
log
.
error
(
f
"Failed to initialize the database connection:
{
e
}
"
)
raise
router
=
Router
(
DB
,
migrate_dir
=
BACKEND_DIR
/
"apps"
/
"webui"
/
"internal"
/
"migrations"
,
logger
=
log
,
finally
:
# Properly closing the database connection
if
db
and
not
db
.
is_closed
():
db
.
close
()
# Assert if db connection has been closed
assert
db
.
is_closed
(),
"Database connection is still open."
handle_peewee_migration
(
DATABASE_URL
)
SQLALCHEMY_DATABASE_URL
=
DATABASE_URL
if
"sqlite"
in
SQLALCHEMY_DATABASE_URL
:
engine
=
create_engine
(
SQLALCHEMY_DATABASE_URL
,
connect_args
=
{
"check_same_thread"
:
False
}
)
else
:
engine
=
create_engine
(
SQLALCHEMY_DATABASE_URL
,
pool_pre_ping
=
True
)
SessionLocal
=
sessionmaker
(
autocommit
=
False
,
autoflush
=
False
,
bind
=
engine
,
expire_on_commit
=
False
)
router
.
run
()
try
:
DB
.
connect
(
reuse_if_open
=
True
)
except
OperationalError
as
e
:
log
.
info
(
f
"Failed to connect to database again due to:
{
e
}
"
)
pass
Base
=
declarative_base
()
Session
=
scoped_session
(
SessionLocal
)
# Dependency
def
get_session
():
db
=
SessionLocal
()
try
:
yield
db
finally
:
db
.
close
()
get_db
=
contextmanager
(
get_session
)
backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py
View file @
f9e3c47d
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
...
...
@@ -21,7 +18,6 @@ Some examples (model - class or model name)::
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from
contextlib
import
suppress
...
...
backend/apps/webui/internal/migrations/README.md
deleted
100644 → 0
View file @
49b4211c
# Database Migrations
This directory contains all the database migrations for the web app.
Migrations are done using the
[
`peewee-migrate`
](
https://github.com/klen/peewee_migrate
)
library.
Migrations are automatically ran at app startup.
## Creating a migration
Have you made a change to the schema of an existing model?
You will need to create a migration file to ensure that existing databases are updated for backwards compatibility.
1.
Have a database file (
`webui.db`
) that has the old schema prior to any of your changes.
2.
Make your changes to the models.
3.
From the
`backend`
directory, run the following command:
```
bash
pw_migrate create
--auto
--auto-source
apps.webui.models
--database
sqlite:///
${
SQLITE_DB
}
--directory
apps/web/internal/migrations
${
MIGRATION_NAME
}
```
-
`$SQLITE_DB`
should be the path to the database file.
-
`$MIGRATION_NAME`
should be a descriptive name for the migration.
4.
The migration file will be created in the
`apps/web/internal/migrations`
directory.
backend/apps/webui/main.py
View file @
f9e3c47d
...
...
@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
from
fastapi.responses
import
StreamingResponse
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
from
sqlalchemy.orm
import
Session
from
apps.webui.routers
import
(
auths
,
users
,
...
...
@@ -19,8 +19,13 @@ from apps.webui.routers import (
functions
,
)
from
apps.webui.models.functions
import
Functions
from
apps.webui.models.models
import
Models
from
apps.webui.utils
import
load_function_module_by_id
from
utils.misc
import
stream_message_template
from
utils.task
import
prompt_template
from
config
import
(
WEBUI_BUILD_HASH
,
...
...
@@ -39,6 +44,8 @@ from config import (
WEBUI_BANNERS
,
ENABLE_COMMUNITY_SHARING
,
AppConfig
,
OAUTH_USERNAME_CLAIM
,
OAUTH_PICTURE_CLAIM
,
)
import
inspect
...
...
@@ -74,6 +81,9 @@ app.state.config.BANNERS = WEBUI_BANNERS
app
.
state
.
config
.
ENABLE_COMMUNITY_SHARING
=
ENABLE_COMMUNITY_SHARING
app
.
state
.
config
.
OAUTH_USERNAME_CLAIM
=
OAUTH_USERNAME_CLAIM
app
.
state
.
config
.
OAUTH_PICTURE_CLAIM
=
OAUTH_PICTURE_CLAIM
app
.
state
.
MODELS
=
{}
app
.
state
.
TOOLS
=
{}
app
.
state
.
FUNCTIONS
=
{}
...
...
@@ -129,7 +139,6 @@ async def get_pipe_models():
function_module
=
app
.
state
.
FUNCTIONS
[
pipe
.
id
]
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
print
(
f
"Getting valves for
{
pipe
.
id
}
"
)
valves
=
Functions
.
get_function_valves_by_id
(
pipe
.
id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
...
...
@@ -181,6 +190,77 @@ async def get_pipe_models():
async
def
generate_function_chat_completion
(
form_data
,
user
):
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
if
model_info
:
if
model_info
.
base_model_id
:
form_data
[
"model"
]
=
model_info
.
base_model_id
model_info
.
params
=
model_info
.
params
.
model_dump
()
if
model_info
.
params
:
if
model_info
.
params
.
get
(
"temperature"
,
None
)
is
not
None
:
form_data
[
"temperature"
]
=
float
(
model_info
.
params
.
get
(
"temperature"
))
if
model_info
.
params
.
get
(
"top_p"
,
None
):
form_data
[
"top_p"
]
=
int
(
model_info
.
params
.
get
(
"top_p"
,
None
))
if
model_info
.
params
.
get
(
"max_tokens"
,
None
):
form_data
[
"max_tokens"
]
=
int
(
model_info
.
params
.
get
(
"max_tokens"
,
None
))
if
model_info
.
params
.
get
(
"frequency_penalty"
,
None
):
form_data
[
"frequency_penalty"
]
=
int
(
model_info
.
params
.
get
(
"frequency_penalty"
,
None
)
)
if
model_info
.
params
.
get
(
"seed"
,
None
):
form_data
[
"seed"
]
=
model_info
.
params
.
get
(
"seed"
,
None
)
if
model_info
.
params
.
get
(
"stop"
,
None
):
form_data
[
"stop"
]
=
(
[
bytes
(
stop
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
stop
in
model_info
.
params
[
"stop"
]
]
if
model_info
.
params
.
get
(
"stop"
,
None
)
else
None
)
system
=
model_info
.
params
.
get
(
"system"
,
None
)
if
system
:
system
=
prompt_template
(
system
,
**
(
{
"user_name"
:
user
.
name
,
"user_location"
:
(
user
.
info
.
get
(
"location"
)
if
user
.
info
else
None
),
}
if
user
else
{}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if
form_data
.
get
(
"messages"
):
for
message
in
form_data
[
"messages"
]:
if
message
.
get
(
"role"
)
==
"system"
:
message
[
"content"
]
=
system
+
message
[
"content"
]
break
else
:
form_data
[
"messages"
].
insert
(
0
,
{
"role"
:
"system"
,
"content"
:
system
,
},
)
else
:
pass
async
def
job
():
pipe_id
=
form_data
[
"model"
]
if
"."
in
pipe_id
:
...
...
backend/apps/webui/models/auths.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
typing
import
List
,
Union
,
Optional
import
time
from
typing
import
Optional
import
uuid
import
logging
from
peewee
import
*
from
sqlalchemy
import
String
,
Column
,
Boolean
,
Text
from
apps.webui.models.users
import
UserModel
,
Users
from
utils.utils
import
verify_password
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
get_db
from
config
import
SRC_LOG_LEVELS
...
...
@@ -20,14 +19,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Auth
(
Model
):
id
=
CharField
(
unique
=
True
)
email
=
CharField
()
password
=
TextField
()
active
=
BooleanField
()
class
Auth
(
Base
):
__tablename__
=
"auth"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
email
=
Column
(
String
)
password
=
Column
(
Text
)
active
=
Column
(
Boolean
)
class
AuthModel
(
BaseModel
):
...
...
@@ -94,9 +92,6 @@ class AddUserForm(SignupForm):
class
AuthsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Auth
])
def
insert_new_auth
(
self
,
...
...
@@ -107,6 +102,8 @@ class AuthsTable:
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
with
get_db
()
as
db
:
log
.
info
(
"insert_new_auth"
)
id
=
str
(
uuid
.
uuid4
())
...
...
@@ -114,12 +111,16 @@ class AuthsTable:
auth
=
AuthModel
(
**
{
"id"
:
id
,
"email"
:
email
,
"password"
:
password
,
"active"
:
True
}
)
result
=
Auth
.
create
(
**
auth
.
model_dump
())
result
=
Auth
(
**
auth
.
model_dump
())
db
.
add
(
result
)
user
=
Users
.
insert_new_user
(
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
and
user
:
return
user
else
:
...
...
@@ -128,7 +129,9 @@ class AuthsTable:
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user:
{
email
}
"
)
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
with
get_db
()
as
db
:
auth
=
db
.
query
(
Auth
).
filter_by
(
email
=
email
,
active
=
True
).
first
()
if
auth
:
if
verify_password
(
password
,
auth
.
password
):
user
=
Users
.
get_user_by_id
(
auth
.
id
)
...
...
@@ -155,7 +158,8 @@ class AuthsTable:
def
authenticate_user_by_trusted_header
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user_by_trusted_header:
{
email
}
"
)
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
with
get_db
()
as
db
:
auth
=
db
.
query
(
Auth
).
filter
(
email
=
email
,
active
=
True
).
first
()
if
auth
:
user
=
Users
.
get_user_by_id
(
auth
.
id
)
return
user
...
...
@@ -164,31 +168,34 @@ class AuthsTable:
def
update_user_password_by_id
(
self
,
id
:
str
,
new_password
:
str
)
->
bool
:
try
:
query
=
Auth
.
update
(
password
=
new_password
).
where
(
Auth
.
id
==
id
)
result
=
query
.
execute
()
with
get_db
()
as
db
:
result
=
(
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"password"
:
new_password
})
)
db
.
commit
()
return
True
if
result
==
1
else
False
except
:
return
False
def
update_email_by_id
(
self
,
id
:
str
,
email
:
str
)
->
bool
:
try
:
query
=
Auth
.
update
(
email
=
email
).
where
(
Auth
.
id
==
id
)
result
=
query
.
execute
(
)
with
get_db
()
as
db
:
result
=
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"email"
:
email
}
)
db
.
commit
()
return
True
if
result
==
1
else
False
except
:
return
False
def
delete_auth_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_db
()
as
db
:
# Delete User
result
=
Users
.
delete_user_by_id
(
id
)
if
result
:
# Delete Auth
query
=
Auth
.
delete
().
where
(
Auth
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
else
:
...
...
@@ -197,4 +204,4 @@ class AuthsTable:
return
False
Auths
=
AuthsTable
(
DB
)
Auths
=
AuthsTable
()
backend/apps/webui/models/chats.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
import
json
import
uuid
import
time
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Boolean
,
Text
from
apps.webui.internal.db
import
Base
,
get_db
####################
# Chat DB Schema
####################
class
Chat
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
title
=
TextField
()
chat
=
TextField
()
# Save Chat JSON as Text
class
Chat
(
Base
):
__tablename__
=
"chat"
created_at
=
BigIntegerField
()
updated_at
=
BigIntegerField
()
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
title
=
Column
(
Text
)
chat
=
Column
(
Text
)
# Save Chat JSON as Text
share_id
=
CharField
(
null
=
True
,
unique
=
True
)
archived
=
BooleanField
(
default
=
False
)
created_at
=
Column
(
BigInteger
)
updated_at
=
Column
(
BigInteger
)
class
Meta
:
database
=
DB
share_id
=
Column
(
Text
,
unique
=
True
,
nullable
=
True
)
archived
=
Column
(
Boolean
,
default
=
False
)
class
ChatModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
id
:
str
user_id
:
str
title
:
str
...
...
@@ -75,18 +77,19 @@ class ChatTitleIdResponse(BaseModel):
class
ChatTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
db
.
create_tables
([
Chat
])
def
insert_new_chat
(
self
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
with
get_db
()
as
db
:
id
=
str
(
uuid
.
uuid4
())
chat
=
ChatModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"title"
:
(
form_data
.
chat
[
"title"
]
if
"title"
in
form_data
.
chat
else
"New Chat"
form_data
.
chat
[
"title"
]
if
"title"
in
form_data
.
chat
else
"New Chat"
),
"chat"
:
json
.
dumps
(
form_data
.
chat
),
"created_at"
:
int
(
time
.
time
()),
...
...
@@ -94,26 +97,32 @@ class ChatTable:
}
)
result
=
Chat
.
create
(
**
chat
.
model_dump
())
return
chat
if
result
else
None
result
=
Chat
(
**
chat
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
return
ChatModel
.
model_validate
(
result
)
if
result
else
None
def
update_chat_by_id
(
self
,
id
:
str
,
chat
:
dict
)
->
Optional
[
ChatModel
]:
try
:
query
=
Chat
.
update
(
chat
=
json
.
dumps
(
chat
),
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
,
updated_at
=
int
(
time
.
time
()),
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
except
:
with
get_db
()
as
db
:
chat_obj
=
db
.
get
(
Chat
,
id
)
chat_obj
.
chat
=
json
.
dumps
(
chat
)
chat_obj
.
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
chat_obj
.
updated_at
=
int
(
time
.
time
())
db
.
commit
()
db
.
refresh
(
chat_obj
)
return
ChatModel
.
model_validate
(
chat_obj
)
except
Exception
as
e
:
return
None
def
insert_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
with
get_db
()
as
db
:
# Get the existing chat to share
chat
=
Chat
.
get
(
Chat
.
id
==
chat_id
)
chat
=
db
.
get
(
Chat
,
chat_id
)
# Check if the chat is already shared
if
chat
.
share_id
:
return
self
.
get_chat_by_id_and_user_id
(
chat
.
share_id
,
"shared"
)
...
...
@@ -128,36 +137,42 @@ class ChatTable:
"updated_at"
:
int
(
time
.
time
()),
}
)
shared_result
=
Chat
.
create
(
**
shared_chat
.
model_dump
())
shared_result
=
Chat
(
**
shared_chat
.
model_dump
())
db
.
add
(
shared_result
)
db
.
commit
()
db
.
refresh
(
shared_result
)
# Update the original chat with the share_id
result
=
(
Chat
.
update
(
share_id
=
shared_chat
.
id
).
where
(
Chat
.
id
==
chat_id
).
execute
()
db
.
query
(
Chat
)
.
filter_by
(
id
=
chat_id
)
.
update
({
"share_id"
:
shared_chat
.
id
})
)
db
.
commit
()
return
shared_chat
if
(
shared_result
and
result
)
else
None
def
update_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
with
get_db
()
as
db
:
print
(
"update_shared_chat_by_id"
)
chat
=
Chat
.
get
(
Chat
.
id
==
chat_id
)
chat
=
db
.
get
(
Chat
,
chat_id
)
print
(
chat
)
chat
.
title
=
chat
.
title
chat
.
chat
=
chat
.
chat
db
.
commit
()
db
.
refresh
(
chat
)
query
=
Chat
.
update
(
title
=
chat
.
title
,
chat
=
chat
.
chat
,
).
where
(
Chat
.
id
==
chat
.
share_id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
chat
.
share_id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
return
self
.
get_chat_by_id
(
chat
.
share_id
)
except
:
return
None
def
delete_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
bool
:
try
:
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
==
f
"shared-
{
chat_id
}
"
)
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
user_id
=
f
"shared-
{
chat_id
}
"
).
delete
()
db
.
commit
()
return
True
except
:
...
...
@@ -167,40 +182,33 @@ class ChatTable:
self
,
id
:
str
,
share_id
:
Optional
[
str
]
)
->
Optional
[
ChatModel
]:
try
:
query
=
Chat
.
update
(
share_id
=
share_id
,
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
with
get_db
()
as
db
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
chat
=
db
.
get
(
Chat
,
id
)
chat
.
share_id
=
share_id
db
.
commit
()
db
.
refresh
(
chat
)
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
toggle_chat_archive_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
self
.
get_chat_by_id
(
id
)
query
=
Chat
.
update
(
archived
=
(
not
chat
.
archived
),
).
where
(
Chat
.
id
==
id
)
with
get_db
()
as
db
:
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
chat
=
db
.
get
(
Chat
,
id
)
chat
.
archived
=
not
chat
.
archived
db
.
commit
()
db
.
refresh
(
chat
)
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
archive_all_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
chats
=
self
.
get_chats_by_user_id
(
user_id
)
for
chat
in
chats
:
query
=
Chat
.
update
(
archived
=
True
,
).
where
(
Chat
.
id
==
chat
.
id
)
query
.
execute
()
with
get_db
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
update
({
"archived"
:
True
})
db
.
commit
()
return
True
except
:
return
False
...
...
@@ -208,15 +216,16 @@ class ChatTable:
def
get_archived_chat_list_by_user_id
(
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
wh
er
e
(
Chat
.
archived
==
True
)
.
where
(
Chat
.
user_id
==
user_id
)
with
get_db
()
as
db
:
all_chats
=
(
db
.
qu
er
y
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
# .limit(limit).offset(skip)
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_list_by_user_id
(
self
,
...
...
@@ -225,92 +234,97 @@ class ChatTable:
skip
:
int
=
0
,
limit
:
int
=
50
,
)
->
List
[
ChatModel
]:
if
include_archived
:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
else
:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
False
)
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
with
get_db
()
as
db
:
query
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
)
if
not
include_archived
:
query
=
query
.
filter_by
(
archived
=
False
)
all_chats
=
(
query
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit).offset(skip)
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_list_by_chat_ids
(
self
,
chat_ids
:
List
[
str
],
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
(
)
.
wh
er
e
(
Chat
.
archived
==
False
)
.
where
(
Chat
.
id
.
in_
(
chat_ids
)
)
with
get_db
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
.
filt
er
(
Chat
.
id
.
in_
(
chat_ids
)
)
.
filter_by
(
archived
=
False
)
.
order_by
(
Chat
.
updated_at
.
desc
())
]
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
with
get_db
()
as
db
:
chat
=
db
.
get
(
Chat
,
id
)
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
get_chat_by_share_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
Chat
.
get
(
Chat
.
share_id
==
id
)
with
get_db
()
as
db
:
chat
=
db
.
query
(
Chat
).
filter_by
(
share_id
=
id
).
first
()
if
chat
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
return
self
.
get_chat_by_id
(
id
)
else
:
return
None
except
:
except
Exception
as
e
:
return
None
def
get_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
,
Chat
.
user_id
==
user_id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
with
get_db
()
as
db
:
chat
=
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
first
()
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
get_chats
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
().
order_by
(
Chat
.
updated_at
.
desc
())
with
get_db
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
# .limit(limit).offset(skip)
]
.
order_by
(
Chat
.
updated_at
.
desc
())
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
user_id
==
user_id
)
with
get_db
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit).offset(skip
)
]
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_archived_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
wh
er
e
(
Chat
.
archived
==
True
)
.
where
(
Chat
.
user_id
==
user_id
)
with
get_db
()
as
db
:
all_chats
=
(
db
.
qu
er
y
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
]
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
delete_chat_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
...
...
@@ -318,8 +332,10 @@ class ChatTable:
def
delete_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
)
&
(
Chat
.
user_id
==
user_id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
db
.
commit
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
...
...
@@ -328,10 +344,12 @@ class ChatTable:
def
delete_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
with
get_db
()
as
db
:
self
.
delete_shared_chats_by_user_id
(
user_id
)
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
==
user_id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
)
.
delete
()
db
.
commit
()
return
True
except
:
...
...
@@ -339,17 +357,18 @@ class ChatTable:
def
delete_shared_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
Chat
.
select
().
where
(
Chat
.
user_id
==
user_id
)
]
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
<<
shared_chat_ids
)
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
chats_by_user
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
all
()
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
chats_by_user
]
db
.
query
(
Chat
).
filter
(
Chat
.
user_id
.
in_
(
shared_chat_ids
)).
delete
()
db
.
commit
()
return
True
except
:
return
False
Chats
=
ChatTable
(
DB
)
Chats
=
ChatTable
()
backend/apps/webui/models/documents.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
import
logging
from
utils.utils
import
decode_token
from
utils.misc
import
get_gravatar_url
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
get_db
import
json
...
...
@@ -22,20 +19,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Document
(
Model
):
collection_name
=
CharField
(
unique
=
True
)
name
=
CharField
(
unique
=
True
)
title
=
TextField
()
filename
=
TextField
()
content
=
TextField
(
null
=
True
)
user_id
=
CharField
()
timestamp
=
BigIntegerField
()
class
Document
(
Base
):
__tablename__
=
"document"
class
Meta
:
database
=
DB
collection_name
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
,
unique
=
True
)
title
=
Column
(
Text
)
filename
=
Column
(
Text
)
content
=
Column
(
Text
,
nullable
=
True
)
user_id
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
DocumentModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
collection_name
:
str
name
:
str
title
:
str
...
...
@@ -72,13 +70,12 @@ class DocumentForm(DocumentUpdateForm):
class
DocumentsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Document
])
def
insert_new_doc
(
self
,
user_id
:
str
,
form_data
:
DocumentForm
)
->
Optional
[
DocumentModel
]:
with
get_db
()
as
db
:
document
=
DocumentModel
(
**
{
**
form_data
.
model_dump
(),
...
...
@@ -88,9 +85,12 @@ class DocumentsTable:
)
try
:
result
=
Document
.
create
(
**
document
.
model_dump
())
result
=
Document
(
**
document
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
d
ocument
return
D
ocument
Model
.
model_validate
(
result
)
else
:
return
None
except
:
...
...
@@ -98,31 +98,35 @@ class DocumentsTable:
def
get_doc_by_name
(
self
,
name
:
str
)
->
Optional
[
DocumentModel
]:
try
:
document
=
Document
.
get
(
Document
.
name
==
name
)
return
DocumentModel
(
**
model_to_dict
(
document
))
with
get_db
()
as
db
:
document
=
db
.
query
(
Document
).
filter_by
(
name
=
name
).
first
()
return
DocumentModel
.
model_validate
(
document
)
if
document
else
None
except
:
return
None
def
get_docs
(
self
)
->
List
[
DocumentModel
]:
with
get_db
()
as
db
:
return
[
DocumentModel
(
**
model_to_dict
(
doc
))
for
doc
in
Document
.
select
()
# .limit(limit).offset(skip)
DocumentModel
.
model_validate
(
doc
)
for
doc
in
db
.
query
(
Document
).
all
()
]
def
update_doc_by_name
(
self
,
name
:
str
,
form_data
:
DocumentUpdateForm
)
->
Optional
[
DocumentModel
]:
try
:
query
=
Document
.
update
(
title
=
form_data
.
title
,
name
=
form_data
.
name
,
timestamp
=
int
(
time
.
time
()),
).
where
(
Document
.
name
==
name
)
query
.
execute
()
doc
=
Document
.
get
(
Document
.
name
==
form_data
.
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
))
with
get_db
()
as
db
:
db
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
{
"title"
:
form_data
.
title
,
"name"
:
form_data
.
name
,
"timestamp"
:
int
(
time
.
time
()),
}
)
db
.
commit
()
return
self
.
get_doc_by_name
(
form_data
.
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
return
None
...
...
@@ -135,26 +139,29 @@ class DocumentsTable:
doc_content
=
json
.
loads
(
doc
.
content
if
doc
.
content
else
"{}"
)
doc_content
=
{
**
doc_content
,
**
updated
}
query
=
Document
.
update
(
content
=
json
.
dumps
(
doc_content
),
timestamp
=
int
(
time
.
time
()),
).
where
(
Document
.
name
==
name
)
query
.
execute
()
with
get_db
()
as
db
:
doc
=
Document
.
get
(
Document
.
name
==
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
))
db
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
{
"content"
:
json
.
dumps
(
doc_content
),
"timestamp"
:
int
(
time
.
time
()),
}
)
db
.
commit
()
return
self
.
get_doc_by_name
(
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
return
None
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
try
:
query
=
Document
.
delete
().
where
((
Document
.
name
==
name
))
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
db
.
query
(
Document
).
filter_by
(
name
=
name
).
delete
()
db
.
commit
()
return
True
except
:
return
False
Documents
=
DocumentsTable
(
DB
)
Documents
=
DocumentsTable
()
backend/apps/webui/models/files.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Text
from
apps.webui.internal.db
import
JSONField
,
Base
,
get_db
import
json
...
...
@@ -18,15 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
File
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
filename
=
TextField
()
meta
=
JSONField
()
created_at
=
BigIntegerField
()
class
File
(
Base
):
__tablename__
=
"file"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
filename
=
Column
(
Text
)
meta
=
Column
(
JSONField
)
created_at
=
Column
(
BigInteger
)
class
FileModel
(
BaseModel
):
...
...
@@ -36,6 +36,8 @@ class FileModel(BaseModel):
meta
:
dict
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -57,11 +59,10 @@ class FileForm(BaseModel):
class
FilesTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
File
])
def
insert_new_file
(
self
,
user_id
:
str
,
form_data
:
FileForm
)
->
Optional
[
FileModel
]:
with
get_db
()
as
db
:
file
=
FileModel
(
**
{
**
form_data
.
model_dump
(),
...
...
@@ -71,9 +72,12 @@ class FilesTable:
)
try
:
result
=
File
.
create
(
**
file
.
model_dump
())
result
=
File
(
**
file
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
f
ile
return
F
ile
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -81,32 +85,42 @@ class FilesTable:
return
None
def
get_file_by_id
(
self
,
id
:
str
)
->
Optional
[
FileModel
]:
with
get_db
()
as
db
:
try
:
file
=
File
.
get
(
File
.
id
==
id
)
return
FileModel
(
**
model_
to_dict
(
file
)
)
file
=
db
.
get
(
File
,
id
)
return
FileModel
.
model_
validate
(
file
)
except
:
return
None
def
get_files
(
self
)
->
List
[
FileModel
]:
return
[
FileModel
(
**
model_to_dict
(
file
))
for
file
in
File
.
select
()]
with
get_db
()
as
db
:
return
[
FileModel
.
model_validate
(
file
)
for
file
in
db
.
query
(
File
).
all
()]
def
delete_file_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
query
=
File
.
delete
().
where
((
File
.
id
==
id
)
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
File
).
filter_by
(
id
=
id
).
delete
(
)
db
.
commit
()
return
True
except
:
return
False
def
delete_all_files
(
self
)
->
bool
:
with
get_db
()
as
db
:
try
:
query
=
File
.
delete
()
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
File
)
.
delete
()
db
.
commit
()
return
True
except
:
return
False
Files
=
FilesTable
(
DB
)
Files
=
FilesTable
()
backend/apps/webui/models/functions.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
Column
,
String
,
Text
,
BigInteger
,
Boolean
from
apps.webui.internal.db
import
JSONField
,
Base
,
get_db
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -21,21 +22,20 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Function
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
name
=
TextField
()
type
=
TextField
()
content
=
TextField
()
meta
=
JSONField
()
valves
=
JSONField
()
is_active
=
BooleanField
(
default
=
False
)
is_global
=
BooleanField
(
default
=
False
)
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Function
(
Base
):
__tablename__
=
"function"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
name
=
Column
(
Text
)
type
=
Column
(
Text
)
content
=
Column
(
Text
)
meta
=
Column
(
JSONField
)
valves
=
Column
(
JSONField
)
is_active
=
Column
(
Boolean
)
is_global
=
Column
(
Boolean
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
FunctionMeta
(
BaseModel
):
...
...
@@ -55,6 +55,8 @@ class FunctionModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -85,13 +87,11 @@ class FunctionValves(BaseModel):
class
FunctionsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Function
])
def
insert_new_function
(
self
,
user_id
:
str
,
type
:
str
,
form_data
:
FunctionForm
)
->
Optional
[
FunctionModel
]:
function
=
FunctionModel
(
**
{
**
form_data
.
model_dump
(),
...
...
@@ -103,9 +103,13 @@ class FunctionsTable:
)
try
:
result
=
Function
.
create
(
**
function
.
model_dump
())
with
get_db
()
as
db
:
result
=
Function
(
**
function
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
f
unction
return
F
unction
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -114,52 +118,60 @@ class FunctionsTable:
def
get_function_by_id
(
self
,
id
:
str
)
->
Optional
[
FunctionModel
]:
try
:
function
=
Function
.
get
(
Function
.
id
==
id
)
return
FunctionModel
(
**
model_to_dict
(
function
))
with
get_db
()
as
db
:
function
=
db
.
get
(
Function
,
id
)
return
FunctionModel
.
model_validate
(
function
)
except
:
return
None
def
get_functions
(
self
,
active_only
=
False
)
->
List
[
FunctionModel
]:
with
get_db
()
as
db
:
if
active_only
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
().
where
(
Function
.
is_active
==
True
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
db
.
query
(
Function
).
filter_by
(
is_active
=
True
)
.
all
()
]
else
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
()
FunctionModel
.
model_
validate
(
function
)
for
function
in
db
.
query
(
Function
).
all
()
]
def
get_functions_by_type
(
self
,
type
:
str
,
active_only
=
False
)
->
List
[
FunctionModel
]:
with
get_db
()
as
db
:
if
active_only
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
().
where
(
Function
.
type
==
type
,
Function
.
is_active
==
True
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
db
.
query
(
Function
)
.
filter_by
(
type
=
type
,
is_active
=
True
)
.
all
(
)
]
else
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
().
where
(
Function
.
type
==
type
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
db
.
query
(
Function
).
filter_by
(
type
=
type
).
all
(
)
]
def
get_global_filter_functions
(
self
)
->
List
[
FunctionModel
]:
with
get_db
()
as
db
:
return
[
FunctionModel
(
**
model_to_dict
(
function
))
for
function
in
Function
.
select
().
where
(
Function
.
type
==
"filter"
,
Function
.
is_active
==
True
,
Function
.
is_global
==
True
,
)
FunctionModel
.
model_validate
(
function
)
for
function
in
db
.
query
(
Function
)
.
filter_by
(
type
=
"filter"
,
is_active
=
True
,
is_global
=
True
)
.
all
()
]
def
get_function_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
with
get_db
()
as
db
:
try
:
function
=
Function
.
get
(
Function
.
id
==
id
)
function
=
db
.
get
(
Function
,
id
)
return
function
.
valves
if
function
.
valves
else
{}
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
...
...
@@ -168,24 +180,25 @@ class FunctionsTable:
def
update_function_valves_by_id
(
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
FunctionValves
]:
with
get_db
()
as
db
:
try
:
query
=
Function
.
update
(
**
{
"valves"
:
valves
},
updated_at
=
int
(
time
.
time
()),
).
where
(
Function
.
id
==
id
)
query
.
execute
()
function
=
Function
.
get
(
Function
.
id
==
id
)
return
FunctionValves
(
**
model_to_dict
(
function
))
function
=
db
.
get
(
Function
,
id
)
function
.
valves
=
valves
function
.
updated_at
=
int
(
time
.
time
())
db
.
commit
()
db
.
refresh
(
function
)
return
self
.
get_function_by_id
(
id
)
except
:
return
None
def
get_user_valves_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
dict
]:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "functions" and "valves" settings
if
"functions"
not
in
user_settings
:
...
...
@@ -201,9 +214,10 @@ class FunctionsTable:
def
update_user_valves_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
,
valves
:
dict
)
->
Optional
[
dict
]:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "functions" and "valves" settings
if
"functions"
not
in
user_settings
:
...
...
@@ -222,39 +236,44 @@ class FunctionsTable:
return
None
def
update_function_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
FunctionModel
]:
with
get_db
()
as
db
:
try
:
query
=
Function
.
update
(
db
.
query
(
Function
).
filter_by
(
id
=
id
).
update
(
{
**
updated
,
updated_at
=
int
(
time
.
time
()),
).
where
(
Function
.
id
==
id
)
query
.
execute
()
function
=
Function
.
get
(
Function
.
id
==
id
)
return
FunctionModel
(
**
model_to_dict
(
function
))
"updated_at"
:
int
(
time
.
time
()),
}
)
db
.
commit
()
return
self
.
get_function_by_id
(
id
)
except
:
return
None
def
deactivate_all_functions
(
self
)
->
Optional
[
bool
]:
with
get_db
()
as
db
:
try
:
query
=
Function
.
update
(
**
{
"is_active"
:
False
},
updated_at
=
int
(
time
.
time
()),
db
.
query
(
Function
).
update
(
{
"is_active"
:
False
,
"updated_at"
:
int
(
time
.
time
()),
}
)
query
.
execute
()
db
.
commit
()
return
True
except
:
return
None
def
delete_function_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
query
=
Function
.
delete
().
where
((
Function
.
id
==
id
)
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Function
).
filter_by
(
id
=
id
).
delete
(
)
db
.
commit
()
return
True
except
:
return
False
Functions
=
FunctionsTable
(
DB
)
Functions
=
FunctionsTable
()
backend/apps/webui/models/memories.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
from
apps.webui.internal.db
import
DB
from
apps.webui.models.chats
import
Chats
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
get_db
import
time
import
uuid
...
...
@@ -14,15 +13,14 @@ import uuid
####################
class
Memory
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
content
=
TextField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Memory
(
Base
):
__tablename__
=
"memory"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
content
=
Column
(
Text
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
MemoryModel
(
BaseModel
):
...
...
@@ -32,6 +30,8 @@ class MemoryModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -39,15 +39,14 @@ class MemoryModel(BaseModel):
class
MemoriesTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Memory
])
def
insert_new_memory
(
self
,
user_id
:
str
,
content
:
str
,
)
->
Optional
[
MemoryModel
]:
with
get_db
()
as
db
:
id
=
str
(
uuid
.
uuid4
())
memory
=
MemoryModel
(
...
...
@@ -59,9 +58,12 @@ class MemoriesTable:
"updated_at"
:
int
(
time
.
time
()),
}
)
result
=
Memory
.
create
(
**
memory
.
model_dump
())
result
=
Memory
(
**
memory
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
m
emory
return
M
emory
Model
.
model_validate
(
result
)
else
:
return
None
...
...
@@ -70,40 +72,50 @@ class MemoriesTable:
id
:
str
,
content
:
str
,
)
->
Optional
[
MemoryModel
]:
with
get_db
()
as
db
:
try
:
memory
=
Memory
.
get
(
Memory
.
id
==
id
)
memory
.
content
=
content
memory
.
updated_at
=
int
(
time
.
time
()
)
memory
.
save
()
return
MemoryModel
(
**
model_to_dict
(
memory
)
)
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
update
(
{
"
content
"
:
content
,
"updated_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
return
self
.
get_memory_by_id
(
id
)
except
:
return
None
def
get_memories
(
self
)
->
List
[
MemoryModel
]:
with
get_db
()
as
db
:
try
:
memories
=
Memory
.
select
()
return
[
MemoryModel
(
**
model_
to_dict
(
memory
)
)
for
memory
in
memories
]
memories
=
db
.
query
(
Memory
).
all
()
return
[
MemoryModel
.
model_
validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memories_by_user_id
(
self
,
user_id
:
str
)
->
List
[
MemoryModel
]:
with
get_db
()
as
db
:
try
:
memories
=
Memory
.
select
().
where
(
Memory
.
user_id
==
user_id
)
return
[
MemoryModel
(
**
model_
to_dict
(
memory
)
)
for
memory
in
memories
]
memories
=
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
)
.
all
()
return
[
MemoryModel
.
model_
validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memory_by_id
(
self
,
id
)
->
Optional
[
MemoryModel
]:
def
get_memory_by_id
(
self
,
id
:
str
)
->
Optional
[
MemoryModel
]:
with
get_db
()
as
db
:
try
:
memory
=
Memory
.
get
(
Memory
.
id
==
id
)
return
MemoryModel
(
**
model_
to_dict
(
memory
)
)
memory
=
db
.
get
(
Memory
,
id
)
return
MemoryModel
.
model_
validate
(
memory
)
except
:
return
None
def
delete_memory_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
delete
(
)
db
.
commit
()
return
True
...
...
@@ -111,22 +123,26 @@ class MemoriesTable:
return
False
def
delete_memories_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
user_id
==
user_id
)
query
.
execute
()
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
)
.
delete
()
db
.
commit
()
return
True
except
:
return
False
def
delete_memory_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
,
Memory
.
user_id
==
user_id
)
query
.
execute
()
db
.
query
(
Memory
).
filter_by
(
id
=
id
,
user_id
=
user_id
)
.
delete
()
db
.
commit
()
return
True
except
:
return
False
Memories
=
MemoriesTable
(
DB
)
Memories
=
MemoriesTable
()
backend/apps/webui/models/models.py
View file @
f9e3c47d
...
...
@@ -2,13 +2,10 @@ import json
import
logging
from
typing
import
Optional
import
peewee
as
pw
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
,
get_db
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
...
...
@@ -32,7 +29,7 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table
class
ModelMeta
(
BaseModel
):
profile_image_url
:
Optional
[
str
]
=
"/favicon.png"
profile_image_url
:
Optional
[
str
]
=
"/
static/
favicon.png"
description
:
Optional
[
str
]
=
None
"""
...
...
@@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
pass
class
Model
(
pw
.
Model
):
id
=
pw
.
TextField
(
unique
=
True
)
class
Model
(
Base
):
__tablename__
=
"model"
id
=
Column
(
Text
,
primary_key
=
True
)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id
=
pw
.
TextField
(
)
user_id
=
Column
(
Text
)
base_model_id
=
pw
.
TextField
(
null
=
True
)
base_model_id
=
Column
(
Text
,
nullable
=
True
)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name
=
pw
.
TextField
(
)
name
=
Column
(
Text
)
"""
The human-readable display name of the model.
"""
params
=
JSONField
(
)
params
=
Column
(
JSONField
)
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta
=
JSONField
(
)
meta
=
Column
(
JSONField
)
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Meta
:
database
=
DB
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
ModelModel
(
BaseModel
):
...
...
@@ -92,6 +88,8 @@ class ModelModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -115,12 +113,6 @@ class ModelForm(BaseModel):
class
ModelsTable
:
def
__init__
(
self
,
db
:
pw
.
SqliteDatabase
|
pw
.
PostgresqlDatabase
,
):
self
.
db
=
db
self
.
db
.
create_tables
([
Model
])
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
...
...
@@ -134,10 +126,16 @@ class ModelsTable:
}
)
try
:
result
=
Model
.
create
(
**
model
.
model_dump
())
with
get_db
()
as
db
:
result
=
Model
(
**
model
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
model
return
ModelModel
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -145,23 +143,33 @@ class ModelsTable:
return
None
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
return
[
ModelModel
(
**
model_to_dict
(
model
))
for
model
in
Model
.
select
()]
with
get_db
()
as
db
:
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_to_dict
(
model
))
with
get_db
()
as
db
:
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
.
model_validate
(
model
)
except
:
return
None
def
update_model_by_id
(
self
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
try
:
with
get_db
()
as
db
:
# update only the fields that are present in the model
query
=
Model
.
update
(
**
model
.
model_dump
()).
where
(
Model
.
id
==
id
)
query
.
execute
()
result
=
(
db
.
query
(
Model
)
.
filter_by
(
id
=
id
)
.
update
(
model
.
model_dump
(
exclude
=
{
"id"
},
exclude_none
=
True
))
)
db
.
commit
()
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_to_dict
(
model
))
model
=
db
.
get
(
Model
,
id
)
db
.
refresh
(
model
)
return
ModelModel
.
model_validate
(
model
)
except
Exception
as
e
:
print
(
e
)
...
...
@@ -169,11 +177,14 @@ class ModelsTable:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Model
.
delete
().
where
(
Model
.
id
==
id
)
query
.
execute
()
with
get_db
()
as
db
:
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
except
:
return
False
Models
=
ModelsTable
(
DB
)
Models
=
ModelsTable
()
backend/apps/webui/models/prompts.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
from
utils.utils
import
decode_token
from
utils.misc
import
get_gravatar_url
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
get_db
import
json
...
...
@@ -16,15 +13,14 @@ import json
####################
class
Prompt
(
Model
):
command
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
title
=
TextField
()
content
=
TextField
()
timestamp
=
BigIntegerField
()
class
Prompt
(
Base
):
__tablename__
=
"prompt"
class
Meta
:
database
=
DB
command
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
title
=
Column
(
Text
)
content
=
Column
(
Text
)
timestamp
=
Column
(
BigInteger
)
class
PromptModel
(
BaseModel
):
...
...
@@ -34,6 +30,8 @@ class PromptModel(BaseModel):
content
:
str
timestamp
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -48,10 +46,6 @@ class PromptForm(BaseModel):
class
PromptsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Prompt
])
def
insert_new_prompt
(
self
,
user_id
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
...
...
@@ -66,53 +60,60 @@ class PromptsTable:
)
try
:
result
=
Prompt
.
create
(
**
prompt
.
model_dump
())
with
get_db
()
as
db
:
result
=
Prompt
(
**
prompt
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
p
rompt
return
P
rompt
Model
.
model_validate
(
result
)
else
:
return
None
except
:
except
Exception
as
e
:
return
None
def
get_prompt_by_command
(
self
,
command
:
str
)
->
Optional
[
PromptModel
]:
try
:
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
return
PromptModel
(
**
model_to_dict
(
prompt
))
with
get_db
()
as
db
:
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
return
PromptModel
.
model_validate
(
prompt
)
except
:
return
None
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
with
get_db
()
as
db
:
return
[
PromptModel
(
**
model_to_dict
(
prompt
))
for
prompt
in
Prompt
.
select
()
# .limit(limit).offset(skip)
PromptModel
.
model_validate
(
prompt
)
for
prompt
in
db
.
query
(
Prompt
).
all
()
]
def
update_prompt_by_command
(
self
,
command
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
try
:
query
=
Prompt
.
update
(
title
=
form_data
.
title
,
content
=
form_data
.
content
,
timestamp
=
int
(
time
.
time
()),
).
where
(
Prompt
.
command
==
command
)
query
.
execute
()
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
return
PromptModel
(
**
model_to_dict
(
prompt
))
with
get_db
()
as
db
:
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
prompt
.
title
=
form_data
.
title
prompt
.
content
=
form_data
.
content
prompt
.
timestamp
=
int
(
time
.
time
())
db
.
commit
()
return
PromptModel
.
model_validate
(
prompt
)
except
:
return
None
def
delete_prompt_by_command
(
self
,
command
:
str
)
->
bool
:
try
:
query
=
Prompt
.
delete
().
where
((
Prompt
.
command
==
command
))
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
delete
()
db
.
commit
()
return
True
except
:
return
False
Prompts
=
PromptsTable
(
DB
)
Prompts
=
PromptsTable
()
backend/apps/webui/models/tags.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
typing
import
List
,
Union
,
Optional
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
json
import
uuid
import
time
import
logging
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
get_db
from
config
import
SRC_LOG_LEVELS
...
...
@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Tag
(
Model
):
id
=
CharField
(
unique
=
True
)
name
=
CharField
()
user_id
=
CharField
()
data
=
TextField
(
null
=
True
)
class
Tag
(
Base
):
__tablename__
=
"tag"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
)
user_id
=
Column
(
String
)
data
=
Column
(
Text
,
nullable
=
True
)
class
ChatIdTag
(
Model
):
id
=
CharField
(
unique
=
True
)
tag_name
=
CharField
()
chat_id
=
CharField
()
user_id
=
CharField
()
timestamp
=
BigIntegerField
()
class
ChatIdTag
(
Base
):
__tablename__
=
"chatidtag"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
tag_name
=
Column
(
String
)
chat_id
=
Column
(
String
)
user_id
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
TagModel
(
BaseModel
):
...
...
@@ -47,6 +45,8 @@ class TagModel(BaseModel):
user_id
:
str
data
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
class
ChatIdTagModel
(
BaseModel
):
id
:
str
...
...
@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
user_id
:
str
timestamp
:
int
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -75,17 +77,19 @@ class ChatTagsResponse(BaseModel):
class
TagTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
db
.
create_tables
([
Tag
,
ChatIdTag
])
def
insert_new_tag
(
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
with
get_db
()
as
db
:
id
=
str
(
uuid
.
uuid4
())
tag
=
TagModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"name"
:
name
})
try
:
result
=
Tag
.
create
(
**
tag
.
model_dump
())
result
=
Tag
(
**
tag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
t
ag
return
T
ag
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -95,8 +99,9 @@ class TagTable:
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
try
:
tag
=
Tag
.
get
(
Tag
.
name
==
name
,
Tag
.
user_id
==
user_id
)
return
TagModel
(
**
model_to_dict
(
tag
))
with
get_db
()
as
db
:
tag
=
db
.
query
(
Tag
).
filter
(
name
=
name
,
user_id
=
user_id
).
first
()
return
TagModel
.
model_validate
(
tag
)
except
Exception
as
e
:
return
None
...
...
@@ -118,81 +123,109 @@ class TagTable:
}
)
try
:
result
=
ChatIdTag
.
create
(
**
chatIdTag
.
model_dump
())
with
get_db
()
as
db
:
result
=
ChatIdTag
(
**
chatIdTag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
c
hatIdTag
return
C
hatIdTag
Model
.
model_validate
(
result
)
else
:
return
None
except
:
return
None
def
get_tags_by_user_id
(
self
,
user_id
:
str
)
->
List
[
TagModel
]:
with
get_db
()
as
db
:
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
(
ChatIdTag
.
user_id
==
user_id
)
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
return
[
TagModel
(
**
model_to_dict
(
tag
))
for
tag
in
Tag
.
select
()
.
where
(
Tag
.
user_id
==
user_id
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
TagModel
.
model_validate
(
tag
)
for
tag
in
(
db
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
def
get_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
)
->
List
[
TagModel
]:
with
get_db
()
as
db
:
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
chat_id
==
chat_id
))
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
chat_id
=
chat_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
return
[
TagModel
(
**
model_to_dict
(
tag
))
for
tag
in
Tag
.
select
()
.
where
(
Tag
.
user_id
==
user_id
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
TagModel
.
model_validate
(
tag
)
for
tag
in
(
db
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
def
get_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
Optional
[
ChatIdTagModel
]:
)
->
List
[
ChatIdTagModel
]:
with
get_db
()
as
db
:
return
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
))
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
tag_name
==
tag_name
))
ChatIdTagModel
.
model_validate
(
chat_id_tag
)
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
tag_name
=
tag_name
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
def
count_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
int
:
with
get_db
()
as
db
:
return
(
ChatIdTag
.
select
(
)
.
where
((
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
)
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
count
()
)
def
delete_tag_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
with
get_db
()
as
db
:
res
=
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
delete
()
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
db
.
commit
()
return
True
except
Exception
as
e
:
log
.
error
(
f
"delete_tag:
{
e
}
"
)
...
...
@@ -202,21 +235,23 @@ class TagTable:
self
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
chat_id
==
chat_id
)
&
(
ChatIdTag
.
user_id
==
user_id
)
with
get_db
()
as
db
:
res
=
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
chat_id
=
chat_id
,
user_id
=
user_id
)
.
delete
()
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
db
.
commit
()
return
True
except
Exception
as
e
:
...
...
@@ -234,4 +269,4 @@ class TagTable:
return
True
Tags
=
TagTable
(
DB
)
Tags
=
TagTable
()
backend/apps/webui/models/tools.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_db
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Tool
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
name
=
TextField
()
content
=
TextField
()
specs
=
JSONField
()
meta
=
JSONField
()
valves
=
JSONField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Tool
(
Base
):
__tablename__
=
"tool"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
name
=
Column
(
Text
)
content
=
Column
(
Text
)
specs
=
Column
(
JSONField
)
meta
=
Column
(
JSONField
)
valves
=
Column
(
JSONField
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
ToolMeta
(
BaseModel
):
...
...
@@ -51,6 +50,8 @@ class ToolModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -78,13 +79,13 @@ class ToolValves(BaseModel):
class
ToolsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Tool
])
def
insert_new_tool
(
self
,
user_id
:
str
,
form_data
:
ToolForm
,
specs
:
List
[
dict
]
)
->
Optional
[
ToolModel
]:
with
get_db
()
as
db
:
tool
=
ToolModel
(
**
{
**
form_data
.
model_dump
(),
...
...
@@ -96,9 +97,12 @@ class ToolsTable:
)
try
:
result
=
Tool
.
create
(
**
tool
.
model_dump
())
result
=
Tool
(
**
tool
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
tool
return
ToolModel
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -107,17 +111,22 @@ class ToolsTable:
def
get_tool_by_id
(
self
,
id
:
str
)
->
Optional
[
ToolModel
]:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolModel
(
**
model_to_dict
(
tool
))
with
get_db
()
as
db
:
tool
=
db
.
get
(
Tool
,
id
)
return
ToolModel
.
model_validate
(
tool
)
except
:
return
None
def
get_tools
(
self
)
->
List
[
ToolModel
]:
return
[
ToolModel
(
**
model_to_dict
(
tool
))
for
tool
in
Tool
.
select
()]
with
get_db
()
as
db
:
return
[
ToolModel
.
model_validate
(
tool
)
for
tool
in
db
.
query
(
Tool
).
all
()]
def
get_tool_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
with
get_db
()
as
db
:
tool
=
db
.
get
(
Tool
,
id
)
return
tool
.
valves
if
tool
.
valves
else
{}
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
...
...
@@ -125,14 +134,13 @@ class ToolsTable:
def
update_tool_valves_by_id
(
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
ToolValves
]:
try
:
query
=
Tool
.
update
(
**
{
"valves"
:
valves
},
updated_at
=
int
(
time
.
time
()),
).
where
(
Tool
.
id
==
id
)
query
.
execute
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolValves
(
**
model_to_dict
(
tool
))
with
get_db
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
"valves"
:
valves
,
"updated_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
return
self
.
get_tool_by_id
(
id
)
except
:
return
None
...
...
@@ -141,7 +149,7 @@ class ToolsTable:
)
->
Optional
[
dict
]:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "tools" and "valves" settings
if
"tools"
not
in
user_settings
:
...
...
@@ -159,7 +167,7 @@ class ToolsTable:
)
->
Optional
[
dict
]:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "tools" and "valves" settings
if
"tools"
not
in
user_settings
:
...
...
@@ -179,25 +187,27 @@ class ToolsTable:
def
update_tool_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
ToolModel
]:
try
:
query
=
Tool
.
update
(
**
updated
,
updated_at
=
int
(
time
.
time
()),
).
where
(
Tool
.
id
==
id
)
query
.
execute
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolModel
(
**
model_to_dict
(
tool
))
with
get_db
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
**
updated
,
"updated_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
tool
=
db
.
query
(
Tool
).
get
(
id
)
db
.
refresh
(
tool
)
return
ToolModel
.
model_validate
(
tool
)
except
:
return
None
def
delete_tool_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Tool
.
delete
().
where
((
Tool
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
except
:
return
False
Tools
=
ToolsTable
(
DB
)
Tools
=
ToolsTable
()
backend/apps/webui/models/users.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
,
parse_obj_as
from
typing
import
List
,
Union
,
Optional
import
time
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
utils.misc
import
get_gravatar_url
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
,
Session
,
get_db
from
apps.webui.models.chats
import
Chats
####################
...
...
@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
####################
class
User
(
Model
):
id
=
CharField
(
unique
=
True
)
name
=
CharField
()
email
=
CharField
()
role
=
CharField
()
profile_image_url
=
TextField
()
class
User
(
Base
):
__tablename__
=
"user"
last_active_at
=
BigIntegerField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
id
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
)
email
=
Column
(
String
)
role
=
Column
(
String
)
profile_image_url
=
Column
(
Text
)
api_key
=
CharField
(
null
=
True
,
unique
=
True
)
settings
=
JSONField
(
null
=
True
)
info
=
JSONField
(
null
=
True
)
last_active_at
=
Column
(
BigInteger
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
oauth_sub
=
TextField
(
null
=
True
,
unique
=
True
)
api_key
=
Column
(
String
,
nullable
=
True
,
unique
=
True
)
settings
=
Column
(
JSONField
,
nullable
=
True
)
info
=
Column
(
JSONField
,
nullable
=
True
)
class
Meta
:
database
=
DB
oauth_sub
=
Column
(
Text
,
unique
=
True
)
class
UserSettings
(
BaseModel
):
...
...
@@ -57,6 +57,8 @@ class UserModel(BaseModel):
oauth_sub
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
class
UsersTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
User
])
def
insert_new_user
(
self
,
...
...
@@ -89,6 +88,7 @@ class UsersTable:
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
with
get_db
()
as
db
:
user
=
UserModel
(
**
{
"id"
:
id
,
...
...
@@ -102,7 +102,10 @@ class UsersTable:
"oauth_sub"
:
oauth_sub
,
}
)
result
=
User
.
create
(
**
user
.
model_dump
())
result
=
User
(
**
user
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
user
else
:
...
...
@@ -110,56 +113,67 @@ class UsersTable:
def
get_user_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
Exception
as
e
:
return
None
def
get_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
api_key
==
api_key
)
return
UserModel
(
**
model_to_dict
(
user
))
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
api_key
=
api_key
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
email
==
email
)
return
UserModel
(
**
model_to_dict
(
user
))
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
email
=
email
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_user_by_oauth_sub
(
self
,
sub
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
oauth_sub
==
sub
)
return
UserModel
(
**
model_to_dict
(
user
))
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
oauth_sub
=
sub
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_users
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
return
[
UserModel
(
**
model_to_dict
(
user
))
for
user
in
User
.
select
()
# .limit(limit).offset(skip)
]
with
get_db
()
as
db
:
users
=
(
db
.
query
(
User
)
# .offset(skip).limit(limit)
.
all
()
)
return
[
UserModel
.
model_validate
(
user
)
for
user
in
users
]
def
get_num_users
(
self
)
->
Optional
[
int
]:
return
User
.
select
().
count
()
with
get_db
()
as
db
:
return
db
.
query
(
User
).
count
()
def
get_first_user
(
self
)
->
UserModel
:
try
:
user
=
User
.
select
().
order_by
(
User
.
created_at
).
first
()
return
UserModel
(
**
model_to_dict
(
user
))
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
order_by
(
User
.
created_at
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
update_user_role_by_id
(
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
role
=
role
).
where
(
User
.
id
==
id
)
query
.
execute
(
)
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
with
get_db
()
as
db
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"role"
:
role
}
)
db
.
commit
()
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
...
...
@@ -167,23 +181,28 @@ class UsersTable:
self
,
id
:
str
,
profile_image_url
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
profile_image_url
=
profile_image_url
).
where
(
User
.
id
==
id
with
get_db
()
as
db
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"profile_image_url"
:
profile_image_url
}
)
query
.
execute
()
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_last_active_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
last_active_at
=
int
(
time
.
time
())).
where
(
User
.
id
==
id
)
query
.
execute
()
with
get_db
()
as
db
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"last_active_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
...
...
@@ -191,22 +210,25 @@ class UsersTable:
self
,
id
:
str
,
oauth_sub
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
oauth_sub
=
oauth_sub
).
where
(
User
.
id
==
id
)
query
.
execute
()
with
get_db
()
as
db
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"oauth_sub"
:
oauth_sub
})
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
**
updated
).
where
(
User
.
id
==
id
)
query
.
execute
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
with
get_db
()
as
db
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
updated
)
db
.
commit
()
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
# return UserModel(**user.dict())
except
Exception
as
e
:
return
None
def
delete_user_by_id
(
self
,
id
:
str
)
->
bool
:
...
...
@@ -215,9 +237,10 @@ class UsersTable:
result
=
Chats
.
delete_chats_by_user_id
(
id
)
if
result
:
with
get_db
()
as
db
:
# Delete User
query
=
User
.
delete
().
where
(
User
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
User
).
filter_by
(
id
=
id
).
delete
(
)
db
.
commit
()
return
True
else
:
...
...
@@ -227,19 +250,20 @@ class UsersTable:
def
update_user_api_key_by_id
(
self
,
id
:
str
,
api_key
:
str
)
->
str
:
try
:
query
=
User
.
update
(
api_key
=
api_key
).
where
(
User
.
id
==
id
)
result
=
query
.
execute
(
)
with
get_db
()
as
db
:
result
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"api_key"
:
api_key
}
)
db
.
commit
()
return
True
if
result
==
1
else
False
except
:
return
False
def
get_user_api_key_by_id
(
self
,
id
:
str
)
->
Optional
[
str
]:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
user
.
api_key
except
:
except
Exception
as
e
:
return
None
Users
=
UsersTable
(
DB
)
Users
=
UsersTable
()
backend/apps/webui/routers/chats.py
View file @
f9e3c47d
...
...
@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@
router
.
get
(
"/list/user/{user_id}"
,
response_model
=
List
[
ChatTitleIdResponse
])
async
def
get_user_chat_list_by_user_id
(
user_id
:
str
,
user
=
Depends
(
get_admin_user
),
skip
:
int
=
0
,
limit
:
int
=
50
user_id
:
str
,
user
=
Depends
(
get_admin_user
),
skip
:
int
=
0
,
limit
:
int
=
50
,
):
return
Chats
.
get_chat_list_by_user_id
(
user_id
,
include_archived
=
True
,
skip
=
skip
,
limit
=
limit
...
...
@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
@
router
.
get
(
"/all/archived"
,
response_model
=
List
[
ChatResponse
])
async
def
get_user_chats
(
user
=
Depends
(
get_verified_user
)):
async
def
get_user_
archived_
chats
(
user
=
Depends
(
get_verified_user
)):
return
[
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
for
chat
in
Chats
.
get_archived_chats_by_user_id
(
user
.
id
)
...
...
@@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name(
form_data
:
TagNameForm
,
user
=
Depends
(
get_verified_user
)
):
print
(
form_data
)
chat_ids
=
[
chat_id_tag
.
chat_id
for
chat_id_tag
in
Tags
.
get_chat_ids_by_tag_name_and_user_id
(
...
...
backend/apps/webui/routers/documents.py
View file @
f9e3c47d
...
...
@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
@
router
.
post
(
"/doc/update"
,
response_model
=
Optional
[
DocumentResponse
])
async
def
update_doc_by_name
(
name
:
str
,
form_data
:
DocumentUpdateForm
,
user
=
Depends
(
get_admin_user
)
name
:
str
,
form_data
:
DocumentUpdateForm
,
user
=
Depends
(
get_admin_user
),
):
doc
=
Documents
.
update_doc_by_name
(
name
,
form_data
)
if
doc
:
...
...
Prev
1
2
3
4
5
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment