Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
f9e3c47d
Commit
f9e3c47d
authored
Jul 09, 2024
by
Michael Poluektov
Browse files
rebase
parents
49b4211c
24ef5af2
Changes
149
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1370 additions
and
951 deletions
+1370
-951
.github/workflows/integration-test.yml
.github/workflows/integration-test.yml
+32
-6
backend/alembic.ini
backend/alembic.ini
+114
-0
backend/apps/rag/main.py
backend/apps/rag/main.py
+2
-1
backend/apps/webui/internal/db.py
backend/apps/webui/internal/db.py
+80
-24
backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py
.../apps/webui/internal/migrations/017_add_user_oauth_sub.py
+0
-4
backend/apps/webui/internal/migrations/README.md
backend/apps/webui/internal/migrations/README.md
+0
-21
backend/apps/webui/main.py
backend/apps/webui/main.py
+82
-2
backend/apps/webui/models/auths.py
backend/apps/webui/models/auths.py
+63
-56
backend/apps/webui/models/chats.py
backend/apps/webui/models/chats.py
+206
-187
backend/apps/webui/models/documents.py
backend/apps/webui/models/documents.py
+69
-62
backend/apps/webui/models/files.py
backend/apps/webui/models/files.py
+64
-50
backend/apps/webui/models/functions.py
backend/apps/webui/models/functions.py
+125
-106
backend/apps/webui/models/memories.py
backend/apps/webui/models/memories.py
+92
-76
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+53
-42
backend/apps/webui/models/prompts.py
backend/apps/webui/models/prompts.py
+47
-46
backend/apps/webui/models/tags.py
backend/apps/webui/models/tags.py
+141
-106
backend/apps/webui/models/tools.py
backend/apps/webui/models/tools.py
+74
-64
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+118
-94
backend/apps/webui/routers/chats.py
backend/apps/webui/routers/chats.py
+5
-3
backend/apps/webui/routers/documents.py
backend/apps/webui/routers/documents.py
+3
-1
No files found.
.github/workflows/integration-test.yml
View file @
f9e3c47d
...
@@ -35,6 +35,10 @@ jobs:
...
@@ -35,6 +35,10 @@ jobs:
done
done
echo "Service is up!"
echo "Service is up!"
-
name
:
Delete Docker build cache
run
:
|
docker builder prune --all --force
-
name
:
Preload Ollama model
-
name
:
Preload Ollama model
run
:
|
run
:
|
docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K
docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K
...
@@ -43,7 +47,7 @@ jobs:
...
@@ -43,7 +47,7 @@ jobs:
uses
:
cypress-io/github-action@v6
uses
:
cypress-io/github-action@v6
with
:
with
:
browser
:
chrome
browser
:
chrome
wait-on
:
"
http://localhost:3000
"
wait-on
:
'
http://localhost:3000
'
config
:
baseUrl=http://localhost:3000
config
:
baseUrl=http://localhost:3000
-
uses
:
actions/upload-artifact@v4
-
uses
:
actions/upload-artifact@v4
...
@@ -67,6 +71,28 @@ jobs:
...
@@ -67,6 +71,28 @@ jobs:
path
:
compose-logs.txt
path
:
compose-logs.txt
if-no-files-found
:
ignore
if-no-files-found
:
ignore
# pytest:
# name: Run Backend Tests
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python
# uses: actions/setup-python@v4
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# pip install -r backend/requirements.txt
# - name: pytest run
# run: |
# ls -al
# cd backend
# PYTHONPATH=. pytest . -o log_cli=true -o log_cli_level=INFO
migration_test
:
migration_test
:
name
:
Run Migration Tests
name
:
Run Migration Tests
runs-on
:
ubuntu-latest
runs-on
:
ubuntu-latest
...
@@ -126,11 +152,11 @@ jobs:
...
@@ -126,11 +152,11 @@ jobs:
cd backend
cd backend
uvicorn main:app --port "8080" --forwarded-allow-ips '*' &
uvicorn main:app --port "8080" --forwarded-allow-ips '*' &
UVICORN_PID=$!
UVICORN_PID=$!
# Wait up to
2
0 seconds for the server to start
# Wait up to
4
0 seconds for the server to start
for i in {1..
2
0}; do
for i in {1..
4
0}; do
curl -s http://localhost:8080/api/config > /dev/null && break
curl -s http://localhost:8080/api/config > /dev/null && break
sleep 1
sleep 1
if [ $i -eq
2
0 ]; then
if [ $i -eq
4
0 ]; then
echo "Server failed to start"
echo "Server failed to start"
kill -9 $UVICORN_PID
kill -9 $UVICORN_PID
exit 1
exit 1
...
@@ -171,7 +197,7 @@ jobs:
...
@@ -171,7 +197,7 @@ jobs:
fi
fi
# Check that service will reconnect to postgres when connection will be closed
# Check that service will reconnect to postgres when connection will be closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health
/db
)
if [[ "$status_code" -ne 200 ]] ; then
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check"
echo "Server has failed before postgres reconnect check"
exit 1
exit 1
...
@@ -183,7 +209,7 @@ jobs:
...
@@ -183,7 +209,7 @@ jobs:
cur = conn.cursor(); \
cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health
/db
)
if [[ "$status_code" -ne 200 ]] ; then
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1
exit 1
...
...
backend/alembic.ini
0 → 100644
View file @
f9e3c47d
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location
=
migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path
=
.
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator
=
os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# sqlalchemy.url = REPLACE_WITH_DATABASE_URL
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys
=
root,sqlalchemy,alembic
[handlers]
keys
=
console
[formatters]
keys
=
generic
[logger_root]
level
=
WARN
handlers
=
console
qualname
=
[logger_sqlalchemy]
level
=
WARN
handlers
=
qualname
=
sqlalchemy.engine
[logger_alembic]
level
=
INFO
handlers
=
qualname
=
alembic
[handler_console]
class
=
StreamHandler
args
=
(sys.stderr,)
level
=
NOTSET
formatter
=
generic
[formatter_generic]
format
=
%(levelname)-5.5s [%(name)s] %(message)s
datefmt
=
%H:%M:%S
backend/apps/rag/main.py
View file @
f9e3c47d
...
@@ -1004,10 +1004,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
...
@@ -1004,10 +1004,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
if
e
.
__class__
.
__name__
==
"UniqueConstraintError"
:
if
e
.
__class__
.
__name__
==
"UniqueConstraintError"
:
return
True
return
True
log
.
exception
(
e
)
return
False
return
False
...
...
backend/apps/webui/internal/db.py
View file @
f9e3c47d
import
os
import
os
import
logging
import
logging
import
json
import
json
from
contextlib
import
contextmanager
from
peewee
import
*
from
peewee_migrate
import
Router
from
peewee_migrate
import
Router
from
apps.webui.internal.wrappers
import
register_connection
from
apps.webui.internal.wrappers
import
register_connection
from
typing
import
Optional
,
Any
from
typing_extensions
import
Self
from
sqlalchemy
import
create_engine
,
types
,
Dialect
from
sqlalchemy.ext.declarative
import
declarative_base
from
sqlalchemy.orm
import
sessionmaker
,
scoped_session
from
sqlalchemy.sql.type_api
import
_T
from
config
import
SRC_LOG_LEVELS
,
DATA_DIR
,
DATABASE_URL
,
BACKEND_DIR
from
config
import
SRC_LOG_LEVELS
,
DATA_DIR
,
DATABASE_URL
,
BACKEND_DIR
log
=
logging
.
getLogger
(
__name__
)
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
class
JSONField
(
TextField
):
class
JSONField
(
types
.
TypeDecorator
):
impl
=
types
.
Text
cache_ok
=
True
def
process_bind_param
(
self
,
value
:
Optional
[
_T
],
dialect
:
Dialect
)
->
Any
:
return
json
.
dumps
(
value
)
def
process_result_value
(
self
,
value
:
Optional
[
_T
],
dialect
:
Dialect
)
->
Any
:
if
value
is
not
None
:
return
json
.
loads
(
value
)
def
copy
(
self
,
**
kw
:
Any
)
->
Self
:
return
JSONField
(
self
.
impl
.
length
)
def
db_value
(
self
,
value
):
def
db_value
(
self
,
value
):
return
json
.
dumps
(
value
)
return
json
.
dumps
(
value
)
...
@@ -30,25 +51,60 @@ else:
...
@@ -30,25 +51,60 @@ else:
pass
pass
# The `register_connection` function encapsulates the logic for setting up
# Workaround to handle the peewee migration
# the database connection based on the connection string, while `connect`
# This is required to ensure the peewee migration is handled before the alembic migration
# is a Peewee-specific method to manage the connection state and avoid errors
def
handle_peewee_migration
(
DATABASE_URL
):
# when a connection is already open.
try
:
try
:
# Replace the postgresql:// with postgres:// and %40 with @ in the DATABASE_URL
DB
=
register_connection
(
DATABASE_URL
)
db
=
register_connection
(
log
.
info
(
f
"Connected to a
{
DB
.
__class__
.
__name__
}
database."
)
DATABASE_URL
.
replace
(
"postgresql://"
,
"postgres://"
).
replace
(
"%40"
,
"@"
)
except
Exception
as
e
:
)
migrate_dir
=
BACKEND_DIR
/
"apps"
/
"webui"
/
"internal"
/
"migrations"
router
=
Router
(
db
,
logger
=
log
,
migrate_dir
=
migrate_dir
)
router
.
run
()
db
.
close
()
# check if db connection has been closed
except
Exception
as
e
:
log
.
error
(
f
"Failed to initialize the database connection:
{
e
}
"
)
log
.
error
(
f
"Failed to initialize the database connection:
{
e
}
"
)
raise
raise
router
=
Router
(
finally
:
DB
,
# Properly closing the database connection
migrate_dir
=
BACKEND_DIR
/
"apps"
/
"webui"
/
"internal"
/
"migrations"
,
if
db
and
not
db
.
is_closed
():
logger
=
log
,
db
.
close
()
# Assert if db connection has been closed
assert
db
.
is_closed
(),
"Database connection is still open."
handle_peewee_migration
(
DATABASE_URL
)
SQLALCHEMY_DATABASE_URL
=
DATABASE_URL
if
"sqlite"
in
SQLALCHEMY_DATABASE_URL
:
engine
=
create_engine
(
SQLALCHEMY_DATABASE_URL
,
connect_args
=
{
"check_same_thread"
:
False
}
)
else
:
engine
=
create_engine
(
SQLALCHEMY_DATABASE_URL
,
pool_pre_ping
=
True
)
SessionLocal
=
sessionmaker
(
autocommit
=
False
,
autoflush
=
False
,
bind
=
engine
,
expire_on_commit
=
False
)
)
router
.
run
()
Base
=
declarative_base
()
try
:
Session
=
scoped_session
(
SessionLocal
)
DB
.
connect
(
reuse_if_open
=
True
)
except
OperationalError
as
e
:
log
.
info
(
f
"Failed to connect to database again due to:
{
e
}
"
)
# Dependency
pass
def
get_session
():
db
=
SessionLocal
()
try
:
yield
db
finally
:
db
.
close
()
get_db
=
contextmanager
(
get_session
)
backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py
View file @
f9e3c47d
"""Peewee migrations -- 017_add_user_oauth_sub.py.
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.create_model(Model) # Create a model (could be used as decorator)
...
@@ -21,7 +18,6 @@ Some examples (model - class or model name)::
...
@@ -21,7 +18,6 @@ Some examples (model - class or model name)::
> migrator.drop_index(model, *col_names)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
> migrator.drop_constraints(model, *constraints)
"""
"""
from
contextlib
import
suppress
from
contextlib
import
suppress
...
...
backend/apps/webui/internal/migrations/README.md
deleted
100644 → 0
View file @
49b4211c
# Database Migrations
This directory contains all the database migrations for the web app.
Migrations are done using the
[
`peewee-migrate`
](
https://github.com/klen/peewee_migrate
)
library.
Migrations are automatically ran at app startup.
## Creating a migration
Have you made a change to the schema of an existing model?
You will need to create a migration file to ensure that existing databases are updated for backwards compatibility.
1.
Have a database file (
`webui.db`
) that has the old schema prior to any of your changes.
2.
Make your changes to the models.
3.
From the
`backend`
directory, run the following command:
```
bash
pw_migrate create
--auto
--auto-source
apps.webui.models
--database
sqlite:///
${
SQLITE_DB
}
--directory
apps/web/internal/migrations
${
MIGRATION_NAME
}
```
-
`$SQLITE_DB`
should be the path to the database file.
-
`$MIGRATION_NAME`
should be a descriptive name for the migration.
4.
The migration file will be created in the
`apps/web/internal/migrations`
directory.
backend/apps/webui/main.py
View file @
f9e3c47d
...
@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
...
@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
from
fastapi.responses
import
StreamingResponse
from
fastapi.responses
import
StreamingResponse
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
from
sqlalchemy.orm
import
Session
from
apps.webui.routers
import
(
from
apps.webui.routers
import
(
auths
,
auths
,
users
,
users
,
...
@@ -19,8 +19,13 @@ from apps.webui.routers import (
...
@@ -19,8 +19,13 @@ from apps.webui.routers import (
functions
,
functions
,
)
)
from
apps.webui.models.functions
import
Functions
from
apps.webui.models.functions
import
Functions
from
apps.webui.models.models
import
Models
from
apps.webui.utils
import
load_function_module_by_id
from
apps.webui.utils
import
load_function_module_by_id
from
utils.misc
import
stream_message_template
from
utils.misc
import
stream_message_template
from
utils.task
import
prompt_template
from
config
import
(
from
config
import
(
WEBUI_BUILD_HASH
,
WEBUI_BUILD_HASH
,
...
@@ -39,6 +44,8 @@ from config import (
...
@@ -39,6 +44,8 @@ from config import (
WEBUI_BANNERS
,
WEBUI_BANNERS
,
ENABLE_COMMUNITY_SHARING
,
ENABLE_COMMUNITY_SHARING
,
AppConfig
,
AppConfig
,
OAUTH_USERNAME_CLAIM
,
OAUTH_PICTURE_CLAIM
,
)
)
import
inspect
import
inspect
...
@@ -74,6 +81,9 @@ app.state.config.BANNERS = WEBUI_BANNERS
...
@@ -74,6 +81,9 @@ app.state.config.BANNERS = WEBUI_BANNERS
app
.
state
.
config
.
ENABLE_COMMUNITY_SHARING
=
ENABLE_COMMUNITY_SHARING
app
.
state
.
config
.
ENABLE_COMMUNITY_SHARING
=
ENABLE_COMMUNITY_SHARING
app
.
state
.
config
.
OAUTH_USERNAME_CLAIM
=
OAUTH_USERNAME_CLAIM
app
.
state
.
config
.
OAUTH_PICTURE_CLAIM
=
OAUTH_PICTURE_CLAIM
app
.
state
.
MODELS
=
{}
app
.
state
.
MODELS
=
{}
app
.
state
.
TOOLS
=
{}
app
.
state
.
TOOLS
=
{}
app
.
state
.
FUNCTIONS
=
{}
app
.
state
.
FUNCTIONS
=
{}
...
@@ -129,7 +139,6 @@ async def get_pipe_models():
...
@@ -129,7 +139,6 @@ async def get_pipe_models():
function_module
=
app
.
state
.
FUNCTIONS
[
pipe
.
id
]
function_module
=
app
.
state
.
FUNCTIONS
[
pipe
.
id
]
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
print
(
f
"Getting valves for
{
pipe
.
id
}
"
)
valves
=
Functions
.
get_function_valves_by_id
(
pipe
.
id
)
valves
=
Functions
.
get_function_valves_by_id
(
pipe
.
id
)
function_module
.
valves
=
function_module
.
Valves
(
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
**
(
valves
if
valves
else
{})
...
@@ -181,6 +190,77 @@ async def get_pipe_models():
...
@@ -181,6 +190,77 @@ async def get_pipe_models():
async
def
generate_function_chat_completion
(
form_data
,
user
):
async
def
generate_function_chat_completion
(
form_data
,
user
):
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
if
model_info
:
if
model_info
.
base_model_id
:
form_data
[
"model"
]
=
model_info
.
base_model_id
model_info
.
params
=
model_info
.
params
.
model_dump
()
if
model_info
.
params
:
if
model_info
.
params
.
get
(
"temperature"
,
None
)
is
not
None
:
form_data
[
"temperature"
]
=
float
(
model_info
.
params
.
get
(
"temperature"
))
if
model_info
.
params
.
get
(
"top_p"
,
None
):
form_data
[
"top_p"
]
=
int
(
model_info
.
params
.
get
(
"top_p"
,
None
))
if
model_info
.
params
.
get
(
"max_tokens"
,
None
):
form_data
[
"max_tokens"
]
=
int
(
model_info
.
params
.
get
(
"max_tokens"
,
None
))
if
model_info
.
params
.
get
(
"frequency_penalty"
,
None
):
form_data
[
"frequency_penalty"
]
=
int
(
model_info
.
params
.
get
(
"frequency_penalty"
,
None
)
)
if
model_info
.
params
.
get
(
"seed"
,
None
):
form_data
[
"seed"
]
=
model_info
.
params
.
get
(
"seed"
,
None
)
if
model_info
.
params
.
get
(
"stop"
,
None
):
form_data
[
"stop"
]
=
(
[
bytes
(
stop
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
stop
in
model_info
.
params
[
"stop"
]
]
if
model_info
.
params
.
get
(
"stop"
,
None
)
else
None
)
system
=
model_info
.
params
.
get
(
"system"
,
None
)
if
system
:
system
=
prompt_template
(
system
,
**
(
{
"user_name"
:
user
.
name
,
"user_location"
:
(
user
.
info
.
get
(
"location"
)
if
user
.
info
else
None
),
}
if
user
else
{}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if
form_data
.
get
(
"messages"
):
for
message
in
form_data
[
"messages"
]:
if
message
.
get
(
"role"
)
==
"system"
:
message
[
"content"
]
=
system
+
message
[
"content"
]
break
else
:
form_data
[
"messages"
].
insert
(
0
,
{
"role"
:
"system"
,
"content"
:
system
,
},
)
else
:
pass
async
def
job
():
async
def
job
():
pipe_id
=
form_data
[
"model"
]
pipe_id
=
form_data
[
"model"
]
if
"."
in
pipe_id
:
if
"."
in
pipe_id
:
...
...
backend/apps/webui/models/auths.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
typing
import
List
,
Union
,
Optional
from
typing
import
Optional
import
time
import
uuid
import
uuid
import
logging
import
logging
from
peewee
import
*
from
sqlalchemy
import
String
,
Column
,
Boolean
,
Text
from
apps.webui.models.users
import
UserModel
,
Users
from
apps.webui.models.users
import
UserModel
,
Users
from
utils.utils
import
verify_password
from
utils.utils
import
verify_password
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
get_db
from
config
import
SRC_LOG_LEVELS
from
config
import
SRC_LOG_LEVELS
...
@@ -20,14 +19,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
...
@@ -20,14 +19,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
####################
class
Auth
(
Model
):
class
Auth
(
Base
):
id
=
CharField
(
unique
=
True
)
__tablename__
=
"auth"
email
=
CharField
()
password
=
TextField
()
active
=
BooleanField
()
class
Meta
:
id
=
Column
(
String
,
primary_key
=
True
)
database
=
DB
email
=
Column
(
String
)
password
=
Column
(
Text
)
active
=
Column
(
Boolean
)
class
AuthModel
(
BaseModel
):
class
AuthModel
(
BaseModel
):
...
@@ -94,9 +92,6 @@ class AddUserForm(SignupForm):
...
@@ -94,9 +92,6 @@ class AddUserForm(SignupForm):
class
AuthsTable
:
class
AuthsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Auth
])
def
insert_new_auth
(
def
insert_new_auth
(
self
,
self
,
...
@@ -107,6 +102,8 @@ class AuthsTable:
...
@@ -107,6 +102,8 @@ class AuthsTable:
role
:
str
=
"pending"
,
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
)
->
Optional
[
UserModel
]:
with
get_db
()
as
db
:
log
.
info
(
"insert_new_auth"
)
log
.
info
(
"insert_new_auth"
)
id
=
str
(
uuid
.
uuid4
())
id
=
str
(
uuid
.
uuid4
())
...
@@ -114,12 +111,16 @@ class AuthsTable:
...
@@ -114,12 +111,16 @@ class AuthsTable:
auth
=
AuthModel
(
auth
=
AuthModel
(
**
{
"id"
:
id
,
"email"
:
email
,
"password"
:
password
,
"active"
:
True
}
**
{
"id"
:
id
,
"email"
:
email
,
"password"
:
password
,
"active"
:
True
}
)
)
result
=
Auth
.
create
(
**
auth
.
model_dump
())
result
=
Auth
(
**
auth
.
model_dump
())
db
.
add
(
result
)
user
=
Users
.
insert_new_user
(
user
=
Users
.
insert_new_user
(
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
)
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
and
user
:
if
result
and
user
:
return
user
return
user
else
:
else
:
...
@@ -128,7 +129,9 @@ class AuthsTable:
...
@@ -128,7 +129,9 @@ class AuthsTable:
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user:
{
email
}
"
)
log
.
info
(
f
"authenticate_user:
{
email
}
"
)
try
:
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
with
get_db
()
as
db
:
auth
=
db
.
query
(
Auth
).
filter_by
(
email
=
email
,
active
=
True
).
first
()
if
auth
:
if
auth
:
if
verify_password
(
password
,
auth
.
password
):
if
verify_password
(
password
,
auth
.
password
):
user
=
Users
.
get_user_by_id
(
auth
.
id
)
user
=
Users
.
get_user_by_id
(
auth
.
id
)
...
@@ -155,7 +158,8 @@ class AuthsTable:
...
@@ -155,7 +158,8 @@ class AuthsTable:
def
authenticate_user_by_trusted_header
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
def
authenticate_user_by_trusted_header
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user_by_trusted_header:
{
email
}
"
)
log
.
info
(
f
"authenticate_user_by_trusted_header:
{
email
}
"
)
try
:
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
with
get_db
()
as
db
:
auth
=
db
.
query
(
Auth
).
filter
(
email
=
email
,
active
=
True
).
first
()
if
auth
:
if
auth
:
user
=
Users
.
get_user_by_id
(
auth
.
id
)
user
=
Users
.
get_user_by_id
(
auth
.
id
)
return
user
return
user
...
@@ -164,31 +168,34 @@ class AuthsTable:
...
@@ -164,31 +168,34 @@ class AuthsTable:
def
update_user_password_by_id
(
self
,
id
:
str
,
new_password
:
str
)
->
bool
:
def
update_user_password_by_id
(
self
,
id
:
str
,
new_password
:
str
)
->
bool
:
try
:
try
:
query
=
Auth
.
update
(
password
=
new_password
).
where
(
Auth
.
id
==
id
)
with
get_db
()
as
db
:
result
=
query
.
execute
()
result
=
(
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"password"
:
new_password
})
)
db
.
commit
()
return
True
if
result
==
1
else
False
return
True
if
result
==
1
else
False
except
:
except
:
return
False
return
False
def
update_email_by_id
(
self
,
id
:
str
,
email
:
str
)
->
bool
:
def
update_email_by_id
(
self
,
id
:
str
,
email
:
str
)
->
bool
:
try
:
try
:
query
=
Auth
.
update
(
email
=
email
).
where
(
Auth
.
id
==
id
)
with
get_db
()
as
db
:
result
=
query
.
execute
(
)
result
=
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"email"
:
email
}
)
db
.
commit
()
return
True
if
result
==
1
else
False
return
True
if
result
==
1
else
False
except
:
except
:
return
False
return
False
def
delete_auth_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_auth_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
try
:
with
get_db
()
as
db
:
# Delete User
# Delete User
result
=
Users
.
delete_user_by_id
(
id
)
result
=
Users
.
delete_user_by_id
(
id
)
if
result
:
if
result
:
# Delete Auth
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
delete
()
query
=
Auth
.
delete
().
where
(
Auth
.
id
==
id
)
db
.
commit
()
query
.
execute
()
# Remove the rows, return number of rows removed.
return
True
return
True
else
:
else
:
...
@@ -197,4 +204,4 @@ class AuthsTable:
...
@@ -197,4 +204,4 @@ class AuthsTable:
return
False
return
False
Auths
=
AuthsTable
(
DB
)
Auths
=
AuthsTable
()
backend/apps/webui/models/chats.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
import
json
import
json
import
uuid
import
uuid
import
time
import
time
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Boolean
,
Text
from
apps.webui.internal.db
import
Base
,
get_db
####################
####################
# Chat DB Schema
# Chat DB Schema
####################
####################
class
Chat
(
Model
):
class
Chat
(
Base
):
id
=
CharField
(
unique
=
True
)
__tablename__
=
"chat"
user_id
=
CharField
()
title
=
TextField
()
chat
=
TextField
()
# Save Chat JSON as Text
created_at
=
BigIntegerField
()
id
=
Column
(
String
,
primary_key
=
True
)
updated_at
=
BigIntegerField
()
user_id
=
Column
(
String
)
title
=
Column
(
Text
)
chat
=
Column
(
Text
)
# Save Chat JSON as Text
share_id
=
CharField
(
null
=
True
,
unique
=
True
)
created_at
=
Column
(
BigInteger
)
archived
=
BooleanField
(
default
=
False
)
updated_at
=
Column
(
BigInteger
)
class
Meta
:
share_id
=
Column
(
Text
,
unique
=
True
,
nullable
=
True
)
database
=
DB
archived
=
Column
(
Boolean
,
default
=
False
)
class
ChatModel
(
BaseModel
):
class
ChatModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
id
:
str
id
:
str
user_id
:
str
user_id
:
str
title
:
str
title
:
str
...
@@ -75,18 +77,19 @@ class ChatTitleIdResponse(BaseModel):
...
@@ -75,18 +77,19 @@ class ChatTitleIdResponse(BaseModel):
class
ChatTable
:
class
ChatTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
db
.
create_tables
([
Chat
])
def
insert_new_chat
(
self
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
def
insert_new_chat
(
self
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
with
get_db
()
as
db
:
id
=
str
(
uuid
.
uuid4
())
id
=
str
(
uuid
.
uuid4
())
chat
=
ChatModel
(
chat
=
ChatModel
(
**
{
**
{
"id"
:
id
,
"id"
:
id
,
"user_id"
:
user_id
,
"user_id"
:
user_id
,
"title"
:
(
"title"
:
(
form_data
.
chat
[
"title"
]
if
"title"
in
form_data
.
chat
else
"New Chat"
form_data
.
chat
[
"title"
]
if
"title"
in
form_data
.
chat
else
"New Chat"
),
),
"chat"
:
json
.
dumps
(
form_data
.
chat
),
"chat"
:
json
.
dumps
(
form_data
.
chat
),
"created_at"
:
int
(
time
.
time
()),
"created_at"
:
int
(
time
.
time
()),
...
@@ -94,26 +97,32 @@ class ChatTable:
...
@@ -94,26 +97,32 @@ class ChatTable:
}
}
)
)
result
=
Chat
.
create
(
**
chat
.
model_dump
())
result
=
Chat
(
**
chat
.
model_dump
())
return
chat
if
result
else
None
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
return
ChatModel
.
model_validate
(
result
)
if
result
else
None
def
update_chat_by_id
(
self
,
id
:
str
,
chat
:
dict
)
->
Optional
[
ChatModel
]:
def
update_chat_by_id
(
self
,
id
:
str
,
chat
:
dict
)
->
Optional
[
ChatModel
]:
try
:
try
:
query
=
Chat
.
update
(
with
get_db
()
as
db
:
chat
=
json
.
dumps
(
chat
),
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
,
chat_obj
=
db
.
get
(
Chat
,
id
)
updated_at
=
int
(
time
.
time
()),
chat_obj
.
chat
=
json
.
dumps
(
chat
)
).
where
(
Chat
.
id
==
id
)
chat_obj
.
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
query
.
execute
()
chat_obj
.
updated_at
=
int
(
time
.
time
())
db
.
commit
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
db
.
refresh
(
chat_obj
)
return
ChatModel
(
**
model_to_dict
(
chat
))
except
:
return
ChatModel
.
model_validate
(
chat_obj
)
except
Exception
as
e
:
return
None
return
None
def
insert_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
def
insert_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
with
get_db
()
as
db
:
# Get the existing chat to share
# Get the existing chat to share
chat
=
Chat
.
get
(
Chat
.
id
==
chat_id
)
chat
=
db
.
get
(
Chat
,
chat_id
)
# Check if the chat is already shared
# Check if the chat is already shared
if
chat
.
share_id
:
if
chat
.
share_id
:
return
self
.
get_chat_by_id_and_user_id
(
chat
.
share_id
,
"shared"
)
return
self
.
get_chat_by_id_and_user_id
(
chat
.
share_id
,
"shared"
)
...
@@ -128,36 +137,42 @@ class ChatTable:
...
@@ -128,36 +137,42 @@ class ChatTable:
"updated_at"
:
int
(
time
.
time
()),
"updated_at"
:
int
(
time
.
time
()),
}
}
)
)
shared_result
=
Chat
.
create
(
**
shared_chat
.
model_dump
())
shared_result
=
Chat
(
**
shared_chat
.
model_dump
())
db
.
add
(
shared_result
)
db
.
commit
()
db
.
refresh
(
shared_result
)
# Update the original chat with the share_id
# Update the original chat with the share_id
result
=
(
result
=
(
Chat
.
update
(
share_id
=
shared_chat
.
id
).
where
(
Chat
.
id
==
chat_id
).
execute
()
db
.
query
(
Chat
)
.
filter_by
(
id
=
chat_id
)
.
update
({
"share_id"
:
shared_chat
.
id
})
)
)
db
.
commit
()
return
shared_chat
if
(
shared_result
and
result
)
else
None
return
shared_chat
if
(
shared_result
and
result
)
else
None
def
update_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
def
update_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
try
:
with
get_db
()
as
db
:
print
(
"update_shared_chat_by_id"
)
print
(
"update_shared_chat_by_id"
)
chat
=
Chat
.
get
(
Chat
.
id
==
chat_id
)
chat
=
db
.
get
(
Chat
,
chat_id
)
print
(
chat
)
print
(
chat
)
chat
.
title
=
chat
.
title
chat
.
chat
=
chat
.
chat
db
.
commit
()
db
.
refresh
(
chat
)
query
=
Chat
.
update
(
return
self
.
get_chat_by_id
(
chat
.
share_id
)
title
=
chat
.
title
,
chat
=
chat
.
chat
,
).
where
(
Chat
.
id
==
chat
.
share_id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
chat
.
share_id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
except
:
except
:
return
None
return
None
def
delete_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
bool
:
def
delete_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
bool
:
try
:
try
:
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
==
f
"shared-
{
chat_id
}
"
)
with
get_db
()
as
db
:
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Chat
).
filter_by
(
user_id
=
f
"shared-
{
chat_id
}
"
).
delete
()
db
.
commit
()
return
True
return
True
except
:
except
:
...
@@ -167,40 +182,33 @@ class ChatTable:
...
@@ -167,40 +182,33 @@ class ChatTable:
self
,
id
:
str
,
share_id
:
Optional
[
str
]
self
,
id
:
str
,
share_id
:
Optional
[
str
]
)
->
Optional
[
ChatModel
]:
)
->
Optional
[
ChatModel
]:
try
:
try
:
query
=
Chat
.
update
(
with
get_db
()
as
db
:
share_id
=
share_id
,
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
chat
=
db
.
get
(
Chat
,
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
chat
.
share_id
=
share_id
db
.
commit
()
db
.
refresh
(
chat
)
return
ChatModel
.
model_validate
(
chat
)
except
:
except
:
return
None
return
None
def
toggle_chat_archive_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
def
toggle_chat_archive_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
try
:
chat
=
self
.
get_chat_by_id
(
id
)
with
get_db
()
as
db
:
query
=
Chat
.
update
(
archived
=
(
not
chat
.
archived
),
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
chat
=
db
.
get
(
Chat
,
id
)
chat
.
archived
=
not
chat
.
archived
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
db
.
commit
()
return
ChatModel
(
**
model_to_dict
(
chat
))
db
.
refresh
(
chat
)
return
ChatModel
.
model_validate
(
chat
)
except
:
except
:
return
None
return
None
def
archive_all_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
def
archive_all_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
try
:
chats
=
self
.
get_chats_by_user_id
(
user_id
)
with
get_db
()
as
db
:
for
chat
in
chats
:
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
update
({
"archived"
:
True
})
query
=
Chat
.
update
(
db
.
commit
()
archived
=
True
,
).
where
(
Chat
.
id
==
chat
.
id
)
query
.
execute
()
return
True
return
True
except
:
except
:
return
False
return
False
...
@@ -208,15 +216,16 @@ class ChatTable:
...
@@ -208,15 +216,16 @@ class ChatTable:
def
get_archived_chat_list_by_user_id
(
def
get_archived_chat_list_by_user_id
(
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
)
->
List
[
ChatModel
]:
return
[
with
get_db
()
as
db
:
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
all_chats
=
(
.
wh
er
e
(
Chat
.
archived
==
True
)
db
.
qu
er
y
(
Chat
)
.
where
(
Chat
.
user_id
==
user_id
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .limit(limit).offset(skip)
# .offset(skip)
.
all
()
]
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_list_by_user_id
(
def
get_chat_list_by_user_id
(
self
,
self
,
...
@@ -225,92 +234,97 @@ class ChatTable:
...
@@ -225,92 +234,97 @@ class ChatTable:
skip
:
int
=
0
,
skip
:
int
=
0
,
limit
:
int
=
50
,
limit
:
int
=
50
,
)
->
List
[
ChatModel
]:
)
->
List
[
ChatModel
]:
if
include_archived
:
with
get_db
()
as
db
:
return
[
query
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
)
ChatModel
(
**
model_to_dict
(
chat
))
if
not
include_archived
:
for
chat
in
Chat
.
select
()
query
=
query
.
filter_by
(
archived
=
False
)
.
where
(
Chat
.
user_id
==
user_id
)
all_chats
=
(
.
order_by
(
Chat
.
updated_at
.
desc
())
query
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .limit(limit).offset(skip)
# .offset(skip)
.
all
()
]
)
else
:
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
False
)
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
def
get_chat_list_by_chat_ids
(
def
get_chat_list_by_chat_ids
(
self
,
chat_ids
:
List
[
str
],
skip
:
int
=
0
,
limit
:
int
=
50
self
,
chat_ids
:
List
[
str
],
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
)
->
List
[
ChatModel
]:
return
[
with
get_db
()
as
db
:
ChatModel
(
**
model_to_dict
(
chat
))
all_chats
=
(
for
chat
in
Chat
.
select
(
)
db
.
query
(
Chat
)
.
wh
er
e
(
Chat
.
archived
==
False
)
.
filt
er
(
Chat
.
id
.
in_
(
chat_ids
)
)
.
where
(
Chat
.
id
.
in_
(
chat_ids
)
)
.
filter_by
(
archived
=
False
)
.
order_by
(
Chat
.
updated_at
.
desc
())
.
order_by
(
Chat
.
updated_at
.
desc
())
]
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
def
get_chat_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
try
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
with
get_db
()
as
db
:
return
ChatModel
(
**
model_to_dict
(
chat
))
chat
=
db
.
get
(
Chat
,
id
)
return
ChatModel
.
model_validate
(
chat
)
except
:
except
:
return
None
return
None
def
get_chat_by_share_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
def
get_chat_by_share_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
try
:
chat
=
Chat
.
get
(
Chat
.
share_id
==
id
)
with
get_db
()
as
db
:
chat
=
db
.
query
(
Chat
).
filter_by
(
share_id
=
id
).
first
()
if
chat
:
if
chat
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
self
.
get_chat_by_id
(
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
else
:
else
:
return
None
return
None
except
:
except
Exception
as
e
:
return
None
return
None
def
get_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
def
get_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
try
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
,
Chat
.
user_id
==
user_id
)
with
get_db
()
as
db
:
return
ChatModel
(
**
model_to_dict
(
chat
))
chat
=
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
first
()
return
ChatModel
.
model_validate
(
chat
)
except
:
except
:
return
None
return
None
def
get_chats
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
def
get_chats
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
with
get_db
()
as
db
:
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
().
order_by
(
Chat
.
updated_at
.
desc
())
all_chats
=
(
db
.
query
(
Chat
)
# .limit(limit).offset(skip)
# .limit(limit).offset(skip)
]
.
order_by
(
Chat
.
updated_at
.
desc
())
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
def
get_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
return
[
with
get_db
()
as
db
:
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
all_chats
=
(
.
where
(
Chat
.
user_id
==
user_id
)
db
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit).offset(skip
)
)
]
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_archived_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
def
get_archived_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
return
[
with
get_db
()
as
db
:
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
all_chats
=
(
.
wh
er
e
(
Chat
.
archived
==
True
)
db
.
qu
er
y
(
Chat
)
.
where
(
Chat
.
user_id
==
user_id
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
.
order_by
(
Chat
.
updated_at
.
desc
())
]
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
delete_chat_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_chat_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
try
:
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
))
with
get_db
()
as
db
:
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
except
:
...
@@ -318,8 +332,10 @@ class ChatTable:
...
@@ -318,8 +332,10 @@ class ChatTable:
def
delete_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
def
delete_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
try
:
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
)
&
(
Chat
.
user_id
==
user_id
))
with
get_db
()
as
db
:
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
db
.
commit
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
except
:
...
@@ -328,10 +344,12 @@ class ChatTable:
...
@@ -328,10 +344,12 @@ class ChatTable:
def
delete_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
def
delete_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
try
:
with
get_db
()
as
db
:
self
.
delete_shared_chats_by_user_id
(
user_id
)
self
.
delete_shared_chats_by_user_id
(
user_id
)
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
==
user_id
)
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
)
.
delete
()
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
commit
()
return
True
return
True
except
:
except
:
...
@@ -339,17 +357,18 @@ class ChatTable:
...
@@ -339,17 +357,18 @@ class ChatTable:
def
delete_shared_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
def
delete_shared_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
try
:
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
Chat
.
select
().
where
(
Chat
.
user_id
==
user_id
)
]
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
<<
shared_chat_ids
)
with
get_db
()
as
db
:
query
.
execute
()
# Remove the rows, return number of rows removed.
chats_by_user
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
all
()
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
chats_by_user
]
db
.
query
(
Chat
).
filter
(
Chat
.
user_id
.
in_
(
shared_chat_ids
)).
delete
()
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
Chats
=
ChatTable
(
DB
)
Chats
=
ChatTable
()
backend/apps/webui/models/documents.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
typing
import
List
,
Optional
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
import
time
import
time
import
logging
import
logging
from
utils.utils
import
decode_token
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
utils.misc
import
get_gravatar_url
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
get_db
import
json
import
json
...
@@ -22,20 +19,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
...
@@ -22,20 +19,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
####################
class
Document
(
Model
):
class
Document
(
Base
):
collection_name
=
CharField
(
unique
=
True
)
__tablename__
=
"document"
name
=
CharField
(
unique
=
True
)
title
=
TextField
()
filename
=
TextField
()
content
=
TextField
(
null
=
True
)
user_id
=
CharField
()
timestamp
=
BigIntegerField
()
class
Meta
:
collection_name
=
Column
(
String
,
primary_key
=
True
)
database
=
DB
name
=
Column
(
String
,
unique
=
True
)
title
=
Column
(
Text
)
filename
=
Column
(
Text
)
content
=
Column
(
Text
,
nullable
=
True
)
user_id
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
DocumentModel
(
BaseModel
):
class
DocumentModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
collection_name
:
str
collection_name
:
str
name
:
str
name
:
str
title
:
str
title
:
str
...
@@ -72,13 +70,12 @@ class DocumentForm(DocumentUpdateForm):
...
@@ -72,13 +70,12 @@ class DocumentForm(DocumentUpdateForm):
class
DocumentsTable
:
class
DocumentsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Document
])
def
insert_new_doc
(
def
insert_new_doc
(
self
,
user_id
:
str
,
form_data
:
DocumentForm
self
,
user_id
:
str
,
form_data
:
DocumentForm
)
->
Optional
[
DocumentModel
]:
)
->
Optional
[
DocumentModel
]:
with
get_db
()
as
db
:
document
=
DocumentModel
(
document
=
DocumentModel
(
**
{
**
{
**
form_data
.
model_dump
(),
**
form_data
.
model_dump
(),
...
@@ -88,9 +85,12 @@ class DocumentsTable:
...
@@ -88,9 +85,12 @@ class DocumentsTable:
)
)
try
:
try
:
result
=
Document
.
create
(
**
document
.
model_dump
())
result
=
Document
(
**
document
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
d
ocument
return
D
ocument
Model
.
model_validate
(
result
)
else
:
else
:
return
None
return
None
except
:
except
:
...
@@ -98,31 +98,35 @@ class DocumentsTable:
...
@@ -98,31 +98,35 @@ class DocumentsTable:
def
get_doc_by_name
(
self
,
name
:
str
)
->
Optional
[
DocumentModel
]:
def
get_doc_by_name
(
self
,
name
:
str
)
->
Optional
[
DocumentModel
]:
try
:
try
:
document
=
Document
.
get
(
Document
.
name
==
name
)
with
get_db
()
as
db
:
return
DocumentModel
(
**
model_to_dict
(
document
))
document
=
db
.
query
(
Document
).
filter_by
(
name
=
name
).
first
()
return
DocumentModel
.
model_validate
(
document
)
if
document
else
None
except
:
except
:
return
None
return
None
def
get_docs
(
self
)
->
List
[
DocumentModel
]:
def
get_docs
(
self
)
->
List
[
DocumentModel
]:
with
get_db
()
as
db
:
return
[
return
[
DocumentModel
(
**
model_to_dict
(
doc
))
DocumentModel
.
model_validate
(
doc
)
for
doc
in
db
.
query
(
Document
).
all
()
for
doc
in
Document
.
select
()
# .limit(limit).offset(skip)
]
]
def
update_doc_by_name
(
def
update_doc_by_name
(
self
,
name
:
str
,
form_data
:
DocumentUpdateForm
self
,
name
:
str
,
form_data
:
DocumentUpdateForm
)
->
Optional
[
DocumentModel
]:
)
->
Optional
[
DocumentModel
]:
try
:
try
:
query
=
Document
.
update
(
with
get_db
()
as
db
:
title
=
form_data
.
title
,
name
=
form_data
.
name
,
db
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
timestamp
=
int
(
time
.
time
()),
{
).
where
(
Document
.
name
==
name
)
"title"
:
form_data
.
title
,
query
.
execute
()
"name"
:
form_data
.
name
,
"timestamp"
:
int
(
time
.
time
()),
doc
=
Document
.
get
(
Document
.
name
==
form_data
.
name
)
}
return
DocumentModel
(
**
model_to_dict
(
doc
))
)
db
.
commit
()
return
self
.
get_doc_by_name
(
form_data
.
name
)
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
return
None
return
None
...
@@ -135,26 +139,29 @@ class DocumentsTable:
...
@@ -135,26 +139,29 @@ class DocumentsTable:
doc_content
=
json
.
loads
(
doc
.
content
if
doc
.
content
else
"{}"
)
doc_content
=
json
.
loads
(
doc
.
content
if
doc
.
content
else
"{}"
)
doc_content
=
{
**
doc_content
,
**
updated
}
doc_content
=
{
**
doc_content
,
**
updated
}
query
=
Document
.
update
(
with
get_db
()
as
db
:
content
=
json
.
dumps
(
doc_content
),
timestamp
=
int
(
time
.
time
()),
).
where
(
Document
.
name
==
name
)
query
.
execute
()
doc
=
Document
.
get
(
Document
.
name
==
name
)
db
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
return
DocumentModel
(
**
model_to_dict
(
doc
))
{
"content"
:
json
.
dumps
(
doc_content
),
"timestamp"
:
int
(
time
.
time
()),
}
)
db
.
commit
()
return
self
.
get_doc_by_name
(
name
)
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
return
None
return
None
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
try
:
try
:
query
=
Document
.
delete
().
where
((
Document
.
name
==
name
))
with
get_db
()
as
db
:
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Document
).
filter_by
(
name
=
name
).
delete
()
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
Documents
=
DocumentsTable
(
DB
)
Documents
=
DocumentsTable
()
backend/apps/webui/models/files.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
import
time
import
time
import
logging
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Text
from
apps.webui.internal.db
import
JSONField
,
Base
,
get_db
import
json
import
json
...
@@ -18,15 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
...
@@ -18,15 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
####################
class
File
(
Model
):
class
File
(
Base
):
id
=
CharField
(
unique
=
True
)
__tablename__
=
"file"
user_id
=
CharField
()
filename
=
TextField
()
meta
=
JSONField
()
created_at
=
BigIntegerField
()
class
Meta
:
id
=
Column
(
String
,
primary_key
=
True
)
database
=
DB
user_id
=
Column
(
String
)
filename
=
Column
(
Text
)
meta
=
Column
(
JSONField
)
created_at
=
Column
(
BigInteger
)
class
FileModel
(
BaseModel
):
class
FileModel
(
BaseModel
):
...
@@ -36,6 +36,8 @@ class FileModel(BaseModel):
...
@@ -36,6 +36,8 @@ class FileModel(BaseModel):
meta
:
dict
meta
:
dict
created_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
####################
# Forms
# Forms
...
@@ -57,11 +59,10 @@ class FileForm(BaseModel):
...
@@ -57,11 +59,10 @@ class FileForm(BaseModel):
class
FilesTable
:
class
FilesTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
File
])
def
insert_new_file
(
self
,
user_id
:
str
,
form_data
:
FileForm
)
->
Optional
[
FileModel
]:
def
insert_new_file
(
self
,
user_id
:
str
,
form_data
:
FileForm
)
->
Optional
[
FileModel
]:
with
get_db
()
as
db
:
file
=
FileModel
(
file
=
FileModel
(
**
{
**
{
**
form_data
.
model_dump
(),
**
form_data
.
model_dump
(),
...
@@ -71,9 +72,12 @@ class FilesTable:
...
@@ -71,9 +72,12 @@ class FilesTable:
)
)
try
:
try
:
result
=
File
.
create
(
**
file
.
model_dump
())
result
=
File
(
**
file
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
f
ile
return
F
ile
Model
.
model_validate
(
result
)
else
:
else
:
return
None
return
None
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -81,32 +85,42 @@ class FilesTable:
...
@@ -81,32 +85,42 @@ class FilesTable:
return
None
return
None
def
get_file_by_id
(
self
,
id
:
str
)
->
Optional
[
FileModel
]:
def
get_file_by_id
(
self
,
id
:
str
)
->
Optional
[
FileModel
]:
with
get_db
()
as
db
:
try
:
try
:
file
=
File
.
get
(
File
.
id
==
id
)
file
=
db
.
get
(
File
,
id
)
return
FileModel
(
**
model_
to_dict
(
file
)
)
return
FileModel
.
model_
validate
(
file
)
except
:
except
:
return
None
return
None
def
get_files
(
self
)
->
List
[
FileModel
]:
def
get_files
(
self
)
->
List
[
FileModel
]:
return
[
FileModel
(
**
model_to_dict
(
file
))
for
file
in
File
.
select
()]
with
get_db
()
as
db
:
return
[
FileModel
.
model_validate
(
file
)
for
file
in
db
.
query
(
File
).
all
()]
def
delete_file_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_file_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
try
:
query
=
File
.
delete
().
where
((
File
.
id
==
id
)
)
db
.
query
(
File
).
filter_by
(
id
=
id
).
delete
(
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
def
delete_all_files
(
self
)
->
bool
:
def
delete_all_files
(
self
)
->
bool
:
with
get_db
()
as
db
:
try
:
try
:
query
=
File
.
delete
()
db
.
query
(
File
)
.
delete
()
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
Files
=
FilesTable
(
DB
)
Files
=
FilesTable
()
backend/apps/webui/models/functions.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
import
time
import
time
import
logging
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
Column
,
String
,
Text
,
BigInteger
,
Boolean
from
apps.webui.internal.db
import
JSONField
,
Base
,
get_db
from
apps.webui.models.users
import
Users
from
apps.webui.models.users
import
Users
import
json
import
json
...
@@ -21,21 +22,20 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
...
@@ -21,21 +22,20 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
####################
class
Function
(
Model
):
class
Function
(
Base
):
id
=
CharField
(
unique
=
True
)
__tablename__
=
"function"
user_id
=
CharField
()
name
=
TextField
()
type
=
TextField
()
content
=
TextField
()
meta
=
JSONField
()
valves
=
JSONField
()
is_active
=
BooleanField
(
default
=
False
)
is_global
=
BooleanField
(
default
=
False
)
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Meta
:
id
=
Column
(
String
,
primary_key
=
True
)
database
=
DB
user_id
=
Column
(
String
)
name
=
Column
(
Text
)
type
=
Column
(
Text
)
content
=
Column
(
Text
)
meta
=
Column
(
JSONField
)
valves
=
Column
(
JSONField
)
is_active
=
Column
(
Boolean
)
is_global
=
Column
(
Boolean
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
FunctionMeta
(
BaseModel
):
class
FunctionMeta
(
BaseModel
):
...
@@ -55,6 +55,8 @@ class FunctionModel(BaseModel):
...
@@ -55,6 +55,8 @@ class FunctionModel(BaseModel):
updated_at
:
int
# timestamp in epoch
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
####################
# Forms
# Forms
...
@@ -85,13 +87,11 @@ class FunctionValves(BaseModel):
...
@@ -85,13 +87,11 @@ class FunctionValves(BaseModel):
class
FunctionsTable
:
class
FunctionsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Function
])
def
insert_new_function
(
def
insert_new_function
(
self
,
user_id
:
str
,
type
:
str
,
form_data
:
FunctionForm
self
,
user_id
:
str
,
type
:
str
,
form_data
:
FunctionForm
)
->
Optional
[
FunctionModel
]:
)
->
Optional
[
FunctionModel
]:
function
=
FunctionModel
(
function
=
FunctionModel
(
**
{
**
{
**
form_data
.
model_dump
(),
**
form_data
.
model_dump
(),
...
@@ -103,9 +103,13 @@ class FunctionsTable:
...
@@ -103,9 +103,13 @@ class FunctionsTable:
)
)
try
:
try
:
result
=
Function
.
create
(
**
function
.
model_dump
())
with
get_db
()
as
db
:
result
=
Function
(
**
function
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
f
unction
return
F
unction
Model
.
model_validate
(
result
)
else
:
else
:
return
None
return
None
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -114,52 +118,60 @@ class FunctionsTable:
...
@@ -114,52 +118,60 @@ class FunctionsTable:
def
get_function_by_id
(
self
,
id
:
str
)
->
Optional
[
FunctionModel
]:
def
get_function_by_id
(
self
,
id
:
str
)
->
Optional
[
FunctionModel
]:
try
:
try
:
function
=
Function
.
get
(
Function
.
id
==
id
)
with
get_db
()
as
db
:
return
FunctionModel
(
**
model_to_dict
(
function
))
function
=
db
.
get
(
Function
,
id
)
return
FunctionModel
.
model_validate
(
function
)
except
:
except
:
return
None
return
None
def
get_functions
(
self
,
active_only
=
False
)
->
List
[
FunctionModel
]:
def
get_functions
(
self
,
active_only
=
False
)
->
List
[
FunctionModel
]:
with
get_db
()
as
db
:
if
active_only
:
if
active_only
:
return
[
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
Function
.
select
().
where
(
Function
.
is_active
==
True
)
for
function
in
db
.
query
(
Function
).
filter_by
(
is_active
=
True
)
.
all
()
]
]
else
:
else
:
return
[
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
Function
.
select
()
for
function
in
db
.
query
(
Function
).
all
()
]
]
def
get_functions_by_type
(
def
get_functions_by_type
(
self
,
type
:
str
,
active_only
=
False
self
,
type
:
str
,
active_only
=
False
)
->
List
[
FunctionModel
]:
)
->
List
[
FunctionModel
]:
with
get_db
()
as
db
:
if
active_only
:
if
active_only
:
return
[
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
Function
.
select
().
where
(
for
function
in
db
.
query
(
Function
)
Function
.
type
==
type
,
Function
.
is_active
==
True
.
filter_by
(
type
=
type
,
is_active
=
True
)
)
.
all
(
)
]
]
else
:
else
:
return
[
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
Function
.
select
().
where
(
Function
.
type
==
type
)
for
function
in
db
.
query
(
Function
).
filter_by
(
type
=
type
).
all
(
)
]
]
def
get_global_filter_functions
(
self
)
->
List
[
FunctionModel
]:
def
get_global_filter_functions
(
self
)
->
List
[
FunctionModel
]:
with
get_db
()
as
db
:
return
[
return
[
FunctionModel
(
**
model_to_dict
(
function
))
FunctionModel
.
model_validate
(
function
)
for
function
in
Function
.
select
().
where
(
for
function
in
db
.
query
(
Function
)
Function
.
type
==
"filter"
,
.
filter_by
(
type
=
"filter"
,
is_active
=
True
,
is_global
=
True
)
Function
.
is_active
==
True
,
.
all
()
Function
.
is_global
==
True
,
)
]
]
def
get_function_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
def
get_function_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
with
get_db
()
as
db
:
try
:
try
:
function
=
Function
.
get
(
Function
.
id
==
id
)
function
=
db
.
get
(
Function
,
id
)
return
function
.
valves
if
function
.
valves
else
{}
return
function
.
valves
if
function
.
valves
else
{}
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
print
(
f
"An error occurred:
{
e
}
"
)
...
@@ -168,24 +180,25 @@ class FunctionsTable:
...
@@ -168,24 +180,25 @@ class FunctionsTable:
def
update_function_valves_by_id
(
def
update_function_valves_by_id
(
self
,
id
:
str
,
valves
:
dict
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
FunctionValves
]:
)
->
Optional
[
FunctionValves
]:
with
get_db
()
as
db
:
try
:
try
:
query
=
Function
.
update
(
function
=
db
.
get
(
Function
,
id
)
**
{
"valves"
:
valves
},
function
.
valves
=
valves
updated_at
=
int
(
time
.
time
()),
function
.
updated_at
=
int
(
time
.
time
())
).
where
(
Function
.
id
==
id
)
db
.
commit
()
query
.
execute
()
db
.
refresh
(
function
)
return
self
.
get_function_by_id
(
id
)
function
=
Function
.
get
(
Function
.
id
==
id
)
return
FunctionValves
(
**
model_to_dict
(
function
))
except
:
except
:
return
None
return
None
def
get_user_valves_by_id_and_user_id
(
def
get_user_valves_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
dict
]:
)
->
Optional
[
dict
]:
try
:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "functions" and "valves" settings
# Check if user has "functions" and "valves" settings
if
"functions"
not
in
user_settings
:
if
"functions"
not
in
user_settings
:
...
@@ -201,9 +214,10 @@ class FunctionsTable:
...
@@ -201,9 +214,10 @@ class FunctionsTable:
def
update_user_valves_by_id_and_user_id
(
def
update_user_valves_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
,
valves
:
dict
self
,
id
:
str
,
user_id
:
str
,
valves
:
dict
)
->
Optional
[
dict
]:
)
->
Optional
[
dict
]:
try
:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "functions" and "valves" settings
# Check if user has "functions" and "valves" settings
if
"functions"
not
in
user_settings
:
if
"functions"
not
in
user_settings
:
...
@@ -222,39 +236,44 @@ class FunctionsTable:
...
@@ -222,39 +236,44 @@ class FunctionsTable:
return
None
return
None
def
update_function_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
FunctionModel
]:
def
update_function_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
FunctionModel
]:
with
get_db
()
as
db
:
try
:
try
:
query
=
Function
.
update
(
db
.
query
(
Function
).
filter_by
(
id
=
id
).
update
(
{
**
updated
,
**
updated
,
updated_at
=
int
(
time
.
time
()),
"updated_at"
:
int
(
time
.
time
()),
).
where
(
Function
.
id
==
id
)
}
query
.
execute
()
)
db
.
commit
()
function
=
Function
.
get
(
Function
.
id
==
id
)
return
self
.
get_function_by_id
(
id
)
return
FunctionModel
(
**
model_to_dict
(
function
))
except
:
except
:
return
None
return
None
def
deactivate_all_functions
(
self
)
->
Optional
[
bool
]:
def
deactivate_all_functions
(
self
)
->
Optional
[
bool
]:
with
get_db
()
as
db
:
try
:
try
:
query
=
Function
.
update
(
db
.
query
(
Function
).
update
(
**
{
"is_active"
:
False
},
{
updated_at
=
int
(
time
.
time
()),
"is_active"
:
False
,
"updated_at"
:
int
(
time
.
time
()),
}
)
)
db
.
commit
()
query
.
execute
()
return
True
return
True
except
:
except
:
return
None
return
None
def
delete_function_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_function_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
try
:
query
=
Function
.
delete
().
where
((
Function
.
id
==
id
)
)
db
.
query
(
Function
).
filter_by
(
id
=
id
).
delete
(
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
Functions
=
FunctionsTable
(
DB
)
Functions
=
FunctionsTable
()
backend/apps/webui/models/memories.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Text
from
apps.webui.models.chats
import
Chats
from
apps.webui.internal.db
import
Base
,
get_db
import
time
import
time
import
uuid
import
uuid
...
@@ -14,15 +13,14 @@ import uuid
...
@@ -14,15 +13,14 @@ import uuid
####################
####################
class
Memory
(
Model
):
class
Memory
(
Base
):
id
=
CharField
(
unique
=
True
)
__tablename__
=
"memory"
user_id
=
CharField
()
content
=
TextField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Meta
:
id
=
Column
(
String
,
primary_key
=
True
)
database
=
DB
user_id
=
Column
(
String
)
content
=
Column
(
Text
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
MemoryModel
(
BaseModel
):
class
MemoryModel
(
BaseModel
):
...
@@ -32,6 +30,8 @@ class MemoryModel(BaseModel):
...
@@ -32,6 +30,8 @@ class MemoryModel(BaseModel):
updated_at
:
int
# timestamp in epoch
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
####################
# Forms
# Forms
...
@@ -39,15 +39,14 @@ class MemoryModel(BaseModel):
...
@@ -39,15 +39,14 @@ class MemoryModel(BaseModel):
class
MemoriesTable
:
class
MemoriesTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Memory
])
def
insert_new_memory
(
def
insert_new_memory
(
self
,
self
,
user_id
:
str
,
user_id
:
str
,
content
:
str
,
content
:
str
,
)
->
Optional
[
MemoryModel
]:
)
->
Optional
[
MemoryModel
]:
with
get_db
()
as
db
:
id
=
str
(
uuid
.
uuid4
())
id
=
str
(
uuid
.
uuid4
())
memory
=
MemoryModel
(
memory
=
MemoryModel
(
...
@@ -59,9 +58,12 @@ class MemoriesTable:
...
@@ -59,9 +58,12 @@ class MemoriesTable:
"updated_at"
:
int
(
time
.
time
()),
"updated_at"
:
int
(
time
.
time
()),
}
}
)
)
result
=
Memory
.
create
(
**
memory
.
model_dump
())
result
=
Memory
(
**
memory
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
m
emory
return
M
emory
Model
.
model_validate
(
result
)
else
:
else
:
return
None
return
None
...
@@ -70,40 +72,50 @@ class MemoriesTable:
...
@@ -70,40 +72,50 @@ class MemoriesTable:
id
:
str
,
id
:
str
,
content
:
str
,
content
:
str
,
)
->
Optional
[
MemoryModel
]:
)
->
Optional
[
MemoryModel
]:
with
get_db
()
as
db
:
try
:
try
:
memory
=
Memory
.
get
(
Memory
.
id
==
id
)
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
update
(
memory
.
content
=
content
{
"
content
"
:
content
,
"updated_at"
:
int
(
time
.
time
())}
memory
.
updated_at
=
int
(
time
.
time
()
)
)
memory
.
save
()
db
.
commit
()
return
MemoryModel
(
**
model_to_dict
(
memory
)
)
return
self
.
get_memory_by_id
(
id
)
except
:
except
:
return
None
return
None
def
get_memories
(
self
)
->
List
[
MemoryModel
]:
def
get_memories
(
self
)
->
List
[
MemoryModel
]:
with
get_db
()
as
db
:
try
:
try
:
memories
=
Memory
.
select
()
memories
=
db
.
query
(
Memory
).
all
()
return
[
MemoryModel
(
**
model_
to_dict
(
memory
)
)
for
memory
in
memories
]
return
[
MemoryModel
.
model_
validate
(
memory
)
for
memory
in
memories
]
except
:
except
:
return
None
return
None
def
get_memories_by_user_id
(
self
,
user_id
:
str
)
->
List
[
MemoryModel
]:
def
get_memories_by_user_id
(
self
,
user_id
:
str
)
->
List
[
MemoryModel
]:
with
get_db
()
as
db
:
try
:
try
:
memories
=
Memory
.
select
().
where
(
Memory
.
user_id
==
user_id
)
memories
=
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
)
.
all
()
return
[
MemoryModel
(
**
model_
to_dict
(
memory
)
)
for
memory
in
memories
]
return
[
MemoryModel
.
model_
validate
(
memory
)
for
memory
in
memories
]
except
:
except
:
return
None
return
None
def
get_memory_by_id
(
self
,
id
)
->
Optional
[
MemoryModel
]:
def
get_memory_by_id
(
self
,
id
:
str
)
->
Optional
[
MemoryModel
]:
with
get_db
()
as
db
:
try
:
try
:
memory
=
Memory
.
get
(
Memory
.
id
==
id
)
memory
=
db
.
get
(
Memory
,
id
)
return
MemoryModel
(
**
model_
to_dict
(
memory
)
)
return
MemoryModel
.
model_
validate
(
memory
)
except
:
except
:
return
None
return
None
def
delete_memory_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_memory_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
)
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
delete
(
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
commit
()
return
True
return
True
...
@@ -111,22 +123,26 @@ class MemoriesTable:
...
@@ -111,22 +123,26 @@ class MemoriesTable:
return
False
return
False
def
delete_memories_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
def
delete_memories_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
user_id
==
user_id
)
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
)
.
delete
()
query
.
execute
()
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
def
delete_memory_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
def
delete_memory_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
with
get_db
()
as
db
:
try
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
,
Memory
.
user_id
==
user_id
)
db
.
query
(
Memory
).
filter_by
(
id
=
id
,
user_id
=
user_id
)
.
delete
()
query
.
execute
()
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
Memories
=
MemoriesTable
(
DB
)
Memories
=
MemoriesTable
()
backend/apps/webui/models/models.py
View file @
f9e3c47d
...
@@ -2,13 +2,10 @@ import json
...
@@ -2,13 +2,10 @@ import json
import
logging
import
logging
from
typing
import
Optional
from
typing
import
Optional
import
peewee
as
pw
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
,
get_db
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
from
config
import
SRC_LOG_LEVELS
...
@@ -32,7 +29,7 @@ class ModelParams(BaseModel):
...
@@ -32,7 +29,7 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table
# ModelMeta is a model for the data stored in the meta field of the Model table
class
ModelMeta
(
BaseModel
):
class
ModelMeta
(
BaseModel
):
profile_image_url
:
Optional
[
str
]
=
"/favicon.png"
profile_image_url
:
Optional
[
str
]
=
"/
static/
favicon.png"
description
:
Optional
[
str
]
=
None
description
:
Optional
[
str
]
=
None
"""
"""
...
@@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
...
@@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
pass
pass
class
Model
(
pw
.
Model
):
class
Model
(
Base
):
id
=
pw
.
TextField
(
unique
=
True
)
__tablename__
=
"model"
id
=
Column
(
Text
,
primary_key
=
True
)
"""
"""
The model's id as used in the API. If set to an existing model, it will override the model.
The model's id as used in the API. If set to an existing model, it will override the model.
"""
"""
user_id
=
pw
.
TextField
(
)
user_id
=
Column
(
Text
)
base_model_id
=
pw
.
TextField
(
null
=
True
)
base_model_id
=
Column
(
Text
,
nullable
=
True
)
"""
"""
An optional pointer to the actual model that should be used when proxying requests.
An optional pointer to the actual model that should be used when proxying requests.
"""
"""
name
=
pw
.
TextField
(
)
name
=
Column
(
Text
)
"""
"""
The human-readable display name of the model.
The human-readable display name of the model.
"""
"""
params
=
JSONField
(
)
params
=
Column
(
JSONField
)
"""
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
"""
meta
=
JSONField
(
)
meta
=
Column
(
JSONField
)
"""
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
"""
updated_at
=
BigIntegerField
()
updated_at
=
Column
(
BigInteger
)
created_at
=
BigIntegerField
()
created_at
=
Column
(
BigInteger
)
class
Meta
:
database
=
DB
class
ModelModel
(
BaseModel
):
class
ModelModel
(
BaseModel
):
...
@@ -92,6 +88,8 @@ class ModelModel(BaseModel):
...
@@ -92,6 +88,8 @@ class ModelModel(BaseModel):
updated_at
:
int
# timestamp in epoch
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
####################
# Forms
# Forms
...
@@ -115,12 +113,6 @@ class ModelForm(BaseModel):
...
@@ -115,12 +113,6 @@ class ModelForm(BaseModel):
class
ModelsTable
:
class
ModelsTable
:
def
__init__
(
self
,
db
:
pw
.
SqliteDatabase
|
pw
.
PostgresqlDatabase
,
):
self
.
db
=
db
self
.
db
.
create_tables
([
Model
])
def
insert_new_model
(
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
self
,
form_data
:
ModelForm
,
user_id
:
str
...
@@ -134,10 +126,16 @@ class ModelsTable:
...
@@ -134,10 +126,16 @@ class ModelsTable:
}
}
)
)
try
:
try
:
result
=
Model
.
create
(
**
model
.
model_dump
())
with
get_db
()
as
db
:
result
=
Model
(
**
model
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
model
return
ModelModel
.
model_validate
(
result
)
else
:
else
:
return
None
return
None
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -145,23 +143,33 @@ class ModelsTable:
...
@@ -145,23 +143,33 @@ class ModelsTable:
return
None
return
None
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
return
[
ModelModel
(
**
model_to_dict
(
model
))
for
model
in
Model
.
select
()]
with
get_db
()
as
db
:
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
try
:
model
=
Model
.
get
(
Model
.
id
==
id
)
with
get_db
()
as
db
:
return
ModelModel
(
**
model_to_dict
(
model
))
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
.
model_validate
(
model
)
except
:
except
:
return
None
return
None
def
update_model_by_id
(
self
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
def
update_model_by_id
(
self
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
try
:
try
:
with
get_db
()
as
db
:
# update only the fields that are present in the model
# update only the fields that are present in the model
query
=
Model
.
update
(
**
model
.
model_dump
()).
where
(
Model
.
id
==
id
)
result
=
(
query
.
execute
()
db
.
query
(
Model
)
.
filter_by
(
id
=
id
)
.
update
(
model
.
model_dump
(
exclude
=
{
"id"
},
exclude_none
=
True
))
)
db
.
commit
()
model
=
Model
.
get
(
Model
.
id
==
id
)
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
(
**
model_to_dict
(
model
))
db
.
refresh
(
model
)
return
ModelModel
.
model_validate
(
model
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
print
(
e
)
...
@@ -169,11 +177,14 @@ class ModelsTable:
...
@@ -169,11 +177,14 @@ class ModelsTable:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
try
:
query
=
Model
.
delete
().
where
(
Model
.
id
==
id
)
with
get_db
()
as
db
:
query
.
execute
()
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
Models
=
ModelsTable
(
DB
)
Models
=
ModelsTable
()
backend/apps/webui/models/prompts.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
typing
import
List
,
Optional
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
import
time
import
time
from
utils.utils
import
decode_token
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
utils.misc
import
get_gravatar_url
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
get_db
import
json
import
json
...
@@ -16,15 +13,14 @@ import json
...
@@ -16,15 +13,14 @@ import json
####################
####################
class
Prompt
(
Model
):
class
Prompt
(
Base
):
command
=
CharField
(
unique
=
True
)
__tablename__
=
"prompt"
user_id
=
CharField
()
title
=
TextField
()
content
=
TextField
()
timestamp
=
BigIntegerField
()
class
Meta
:
command
=
Column
(
String
,
primary_key
=
True
)
database
=
DB
user_id
=
Column
(
String
)
title
=
Column
(
Text
)
content
=
Column
(
Text
)
timestamp
=
Column
(
BigInteger
)
class
PromptModel
(
BaseModel
):
class
PromptModel
(
BaseModel
):
...
@@ -34,6 +30,8 @@ class PromptModel(BaseModel):
...
@@ -34,6 +30,8 @@ class PromptModel(BaseModel):
content
:
str
content
:
str
timestamp
:
int
# timestamp in epoch
timestamp
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
####################
# Forms
# Forms
...
@@ -48,10 +46,6 @@ class PromptForm(BaseModel):
...
@@ -48,10 +46,6 @@ class PromptForm(BaseModel):
class
PromptsTable
:
class
PromptsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Prompt
])
def
insert_new_prompt
(
def
insert_new_prompt
(
self
,
user_id
:
str
,
form_data
:
PromptForm
self
,
user_id
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
)
->
Optional
[
PromptModel
]:
...
@@ -66,53 +60,60 @@ class PromptsTable:
...
@@ -66,53 +60,60 @@ class PromptsTable:
)
)
try
:
try
:
result
=
Prompt
.
create
(
**
prompt
.
model_dump
())
with
get_db
()
as
db
:
result
=
Prompt
(
**
prompt
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
p
rompt
return
P
rompt
Model
.
model_validate
(
result
)
else
:
else
:
return
None
return
None
except
:
except
Exception
as
e
:
return
None
return
None
def
get_prompt_by_command
(
self
,
command
:
str
)
->
Optional
[
PromptModel
]:
def
get_prompt_by_command
(
self
,
command
:
str
)
->
Optional
[
PromptModel
]:
try
:
try
:
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
with
get_db
()
as
db
:
return
PromptModel
(
**
model_to_dict
(
prompt
))
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
return
PromptModel
.
model_validate
(
prompt
)
except
:
except
:
return
None
return
None
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
with
get_db
()
as
db
:
return
[
return
[
PromptModel
(
**
model_to_dict
(
prompt
))
PromptModel
.
model_validate
(
prompt
)
for
prompt
in
db
.
query
(
Prompt
).
all
()
for
prompt
in
Prompt
.
select
()
# .limit(limit).offset(skip)
]
]
def
update_prompt_by_command
(
def
update_prompt_by_command
(
self
,
command
:
str
,
form_data
:
PromptForm
self
,
command
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
)
->
Optional
[
PromptModel
]:
try
:
try
:
query
=
Prompt
.
update
(
with
get_db
()
as
db
:
title
=
form_data
.
title
,
content
=
form_data
.
content
,
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
timestamp
=
int
(
time
.
time
()),
prompt
.
title
=
form_data
.
title
).
where
(
Prompt
.
command
==
command
)
prompt
.
content
=
form_data
.
content
prompt
.
timestamp
=
int
(
time
.
time
())
query
.
execute
()
db
.
commit
()
return
PromptModel
.
model_validate
(
prompt
)
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
return
PromptModel
(
**
model_to_dict
(
prompt
))
except
:
except
:
return
None
return
None
def
delete_prompt_by_command
(
self
,
command
:
str
)
->
bool
:
def
delete_prompt_by_command
(
self
,
command
:
str
)
->
bool
:
try
:
try
:
query
=
Prompt
.
delete
().
where
((
Prompt
.
command
==
command
))
with
get_db
()
as
db
:
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
delete
()
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
Prompts
=
PromptsTable
(
DB
)
Prompts
=
PromptsTable
()
backend/apps/webui/models/tags.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Optional
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
import
json
import
json
import
uuid
import
uuid
import
time
import
time
import
logging
import
logging
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
get_db
from
config
import
SRC_LOG_LEVELS
from
config
import
SRC_LOG_LEVELS
...
@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
...
@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
####################
class
Tag
(
Model
):
class
Tag
(
Base
):
id
=
CharField
(
unique
=
True
)
__tablename__
=
"tag"
name
=
CharField
()
user_id
=
CharField
()
data
=
TextField
(
null
=
True
)
class
Meta
:
id
=
Column
(
String
,
primary_key
=
True
)
database
=
DB
name
=
Column
(
String
)
user_id
=
Column
(
String
)
data
=
Column
(
Text
,
nullable
=
True
)
class
ChatIdTag
(
Model
):
class
ChatIdTag
(
Base
):
id
=
CharField
(
unique
=
True
)
__tablename__
=
"chatidtag"
tag_name
=
CharField
()
chat_id
=
CharField
()
user_id
=
CharField
()
timestamp
=
BigIntegerField
()
class
Meta
:
id
=
Column
(
String
,
primary_key
=
True
)
database
=
DB
tag_name
=
Column
(
String
)
chat_id
=
Column
(
String
)
user_id
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
TagModel
(
BaseModel
):
class
TagModel
(
BaseModel
):
...
@@ -47,6 +45,8 @@ class TagModel(BaseModel):
...
@@ -47,6 +45,8 @@ class TagModel(BaseModel):
user_id
:
str
user_id
:
str
data
:
Optional
[
str
]
=
None
data
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
class
ChatIdTagModel
(
BaseModel
):
class
ChatIdTagModel
(
BaseModel
):
id
:
str
id
:
str
...
@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
...
@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
user_id
:
str
user_id
:
str
timestamp
:
int
timestamp
:
int
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
####################
# Forms
# Forms
...
@@ -75,17 +77,19 @@ class ChatTagsResponse(BaseModel):
...
@@ -75,17 +77,19 @@ class ChatTagsResponse(BaseModel):
class
TagTable
:
class
TagTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
db
.
create_tables
([
Tag
,
ChatIdTag
])
def
insert_new_tag
(
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
def
insert_new_tag
(
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
with
get_db
()
as
db
:
id
=
str
(
uuid
.
uuid4
())
id
=
str
(
uuid
.
uuid4
())
tag
=
TagModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"name"
:
name
})
tag
=
TagModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"name"
:
name
})
try
:
try
:
result
=
Tag
.
create
(
**
tag
.
model_dump
())
result
=
Tag
(
**
tag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
t
ag
return
T
ag
Model
.
model_validate
(
result
)
else
:
else
:
return
None
return
None
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -95,8 +99,9 @@ class TagTable:
...
@@ -95,8 +99,9 @@ class TagTable:
self
,
name
:
str
,
user_id
:
str
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
)
->
Optional
[
TagModel
]:
try
:
try
:
tag
=
Tag
.
get
(
Tag
.
name
==
name
,
Tag
.
user_id
==
user_id
)
with
get_db
()
as
db
:
return
TagModel
(
**
model_to_dict
(
tag
))
tag
=
db
.
query
(
Tag
).
filter
(
name
=
name
,
user_id
=
user_id
).
first
()
return
TagModel
.
model_validate
(
tag
)
except
Exception
as
e
:
except
Exception
as
e
:
return
None
return
None
...
@@ -118,81 +123,109 @@ class TagTable:
...
@@ -118,81 +123,109 @@ class TagTable:
}
}
)
)
try
:
try
:
result
=
ChatIdTag
.
create
(
**
chatIdTag
.
model_dump
())
with
get_db
()
as
db
:
result
=
ChatIdTag
(
**
chatIdTag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
c
hatIdTag
return
C
hatIdTag
Model
.
model_validate
(
result
)
else
:
else
:
return
None
return
None
except
:
except
:
return
None
return
None
def
get_tags_by_user_id
(
self
,
user_id
:
str
)
->
List
[
TagModel
]:
def
get_tags_by_user_id
(
self
,
user_id
:
str
)
->
List
[
TagModel
]:
with
get_db
()
as
db
:
tag_names
=
[
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
chat_id_tag
.
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
for
chat_id_tag
in
(
.
where
(
ChatIdTag
.
user_id
==
user_id
)
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
]
return
[
return
[
TagModel
(
**
model_to_dict
(
tag
))
TagModel
.
model_validate
(
tag
)
for
tag
in
Tag
.
select
()
for
tag
in
(
.
where
(
Tag
.
user_id
==
user_id
)
db
.
query
(
Tag
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
]
def
get_tags_by_chat_id_and_user_id
(
def
get_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
self
,
chat_id
:
str
,
user_id
:
str
)
->
List
[
TagModel
]:
)
->
List
[
TagModel
]:
with
get_db
()
as
db
:
tag_names
=
[
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
chat_id_tag
.
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
for
chat_id_tag
in
(
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
chat_id
==
chat_id
))
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
chat_id
=
chat_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
]
return
[
return
[
TagModel
(
**
model_to_dict
(
tag
))
TagModel
.
model_validate
(
tag
)
for
tag
in
Tag
.
select
()
for
tag
in
(
.
where
(
Tag
.
user_id
==
user_id
)
db
.
query
(
Tag
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
]
def
get_chat_ids_by_tag_name_and_user_id
(
def
get_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
self
,
tag_name
:
str
,
user_id
:
str
)
->
Optional
[
ChatIdTagModel
]:
)
->
List
[
ChatIdTagModel
]:
with
get_db
()
as
db
:
return
[
return
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
))
ChatIdTagModel
.
model_validate
(
chat_id_tag
)
for
chat_id_tag
in
ChatIdTag
.
select
()
for
chat_id_tag
in
(
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
tag_name
==
tag_name
))
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
tag_name
=
tag_name
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
]
def
count_chat_ids_by_tag_name_and_user_id
(
def
count_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
self
,
tag_name
:
str
,
user_id
:
str
)
->
int
:
)
->
int
:
with
get_db
()
as
db
:
return
(
return
(
ChatIdTag
.
select
(
)
db
.
query
(
ChatIdTag
)
.
where
((
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
count
()
.
count
()
)
)
def
delete_tag_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
bool
:
def
delete_tag_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
bool
:
try
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
with
get_db
()
as
db
:
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
res
=
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
delete
()
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
if
tag_count
==
0
:
# Remove tag item from Tag col as well
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
db
.
commit
()
)
query
.
execute
()
# Remove the rows, return number of rows removed.
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
log
.
error
(
f
"delete_tag:
{
e
}
"
)
log
.
error
(
f
"delete_tag:
{
e
}
"
)
...
@@ -202,21 +235,23 @@ class TagTable:
...
@@ -202,21 +235,23 @@ class TagTable:
self
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
self
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
)
->
bool
:
try
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
with
get_db
()
as
db
:
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
chat_id
==
chat_id
)
res
=
(
&
(
ChatIdTag
.
user_id
==
user_id
)
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
chat_id
=
chat_id
,
user_id
=
user_id
)
.
delete
()
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
if
tag_count
==
0
:
# Remove tag item from Tag col as well
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
db
.
commit
()
)
query
.
execute
()
# Remove the rows, return number of rows removed.
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -234,4 +269,4 @@ class TagTable:
...
@@ -234,4 +269,4 @@ class TagTable:
return
True
return
True
Tags
=
TagTable
(
DB
)
Tags
=
TagTable
()
backend/apps/webui/models/tools.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
typing
import
List
,
Optional
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
import
time
import
time
import
logging
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_db
from
apps.webui.models.users
import
Users
from
apps.webui.models.users
import
Users
import
json
import
json
...
@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
...
@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
####################
class
Tool
(
Model
):
class
Tool
(
Base
):
id
=
CharField
(
unique
=
True
)
__tablename__
=
"tool"
user_id
=
CharField
()
name
=
TextField
()
content
=
TextField
()
specs
=
JSONField
()
meta
=
JSONField
()
valves
=
JSONField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Meta
:
id
=
Column
(
String
,
primary_key
=
True
)
database
=
DB
user_id
=
Column
(
String
)
name
=
Column
(
Text
)
content
=
Column
(
Text
)
specs
=
Column
(
JSONField
)
meta
=
Column
(
JSONField
)
valves
=
Column
(
JSONField
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
ToolMeta
(
BaseModel
):
class
ToolMeta
(
BaseModel
):
...
@@ -51,6 +50,8 @@ class ToolModel(BaseModel):
...
@@ -51,6 +50,8 @@ class ToolModel(BaseModel):
updated_at
:
int
# timestamp in epoch
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
####################
# Forms
# Forms
...
@@ -78,13 +79,13 @@ class ToolValves(BaseModel):
...
@@ -78,13 +79,13 @@ class ToolValves(BaseModel):
class
ToolsTable
:
class
ToolsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Tool
])
def
insert_new_tool
(
def
insert_new_tool
(
self
,
user_id
:
str
,
form_data
:
ToolForm
,
specs
:
List
[
dict
]
self
,
user_id
:
str
,
form_data
:
ToolForm
,
specs
:
List
[
dict
]
)
->
Optional
[
ToolModel
]:
)
->
Optional
[
ToolModel
]:
with
get_db
()
as
db
:
tool
=
ToolModel
(
tool
=
ToolModel
(
**
{
**
{
**
form_data
.
model_dump
(),
**
form_data
.
model_dump
(),
...
@@ -96,9 +97,12 @@ class ToolsTable:
...
@@ -96,9 +97,12 @@ class ToolsTable:
)
)
try
:
try
:
result
=
Tool
.
create
(
**
tool
.
model_dump
())
result
=
Tool
(
**
tool
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
tool
return
ToolModel
.
model_validate
(
result
)
else
:
else
:
return
None
return
None
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -107,17 +111,22 @@ class ToolsTable:
...
@@ -107,17 +111,22 @@ class ToolsTable:
def
get_tool_by_id
(
self
,
id
:
str
)
->
Optional
[
ToolModel
]:
def
get_tool_by_id
(
self
,
id
:
str
)
->
Optional
[
ToolModel
]:
try
:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
with
get_db
()
as
db
:
return
ToolModel
(
**
model_to_dict
(
tool
))
tool
=
db
.
get
(
Tool
,
id
)
return
ToolModel
.
model_validate
(
tool
)
except
:
except
:
return
None
return
None
def
get_tools
(
self
)
->
List
[
ToolModel
]:
def
get_tools
(
self
)
->
List
[
ToolModel
]:
return
[
ToolModel
(
**
model_to_dict
(
tool
))
for
tool
in
Tool
.
select
()]
with
get_db
()
as
db
:
return
[
ToolModel
.
model_validate
(
tool
)
for
tool
in
db
.
query
(
Tool
).
all
()]
def
get_tool_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
def
get_tool_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
with
get_db
()
as
db
:
tool
=
db
.
get
(
Tool
,
id
)
return
tool
.
valves
if
tool
.
valves
else
{}
return
tool
.
valves
if
tool
.
valves
else
{}
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
print
(
f
"An error occurred:
{
e
}
"
)
...
@@ -125,14 +134,13 @@ class ToolsTable:
...
@@ -125,14 +134,13 @@ class ToolsTable:
def
update_tool_valves_by_id
(
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
ToolValves
]:
def
update_tool_valves_by_id
(
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
ToolValves
]:
try
:
try
:
query
=
Tool
.
update
(
with
get_db
()
as
db
:
**
{
"valves"
:
valves
},
updated_at
=
int
(
time
.
time
()),
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
).
where
(
Tool
.
id
==
id
)
{
"valves"
:
valves
,
"updated_at"
:
int
(
time
.
time
())}
query
.
execute
()
)
db
.
commit
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
self
.
get_tool_by_id
(
id
)
return
ToolValves
(
**
model_to_dict
(
tool
))
except
:
except
:
return
None
return
None
...
@@ -141,7 +149,7 @@ class ToolsTable:
...
@@ -141,7 +149,7 @@ class ToolsTable:
)
->
Optional
[
dict
]:
)
->
Optional
[
dict
]:
try
:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "tools" and "valves" settings
# Check if user has "tools" and "valves" settings
if
"tools"
not
in
user_settings
:
if
"tools"
not
in
user_settings
:
...
@@ -159,7 +167,7 @@ class ToolsTable:
...
@@ -159,7 +167,7 @@ class ToolsTable:
)
->
Optional
[
dict
]:
)
->
Optional
[
dict
]:
try
:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "tools" and "valves" settings
# Check if user has "tools" and "valves" settings
if
"tools"
not
in
user_settings
:
if
"tools"
not
in
user_settings
:
...
@@ -179,25 +187,27 @@ class ToolsTable:
...
@@ -179,25 +187,27 @@ class ToolsTable:
def
update_tool_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
ToolModel
]:
def
update_tool_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
ToolModel
]:
try
:
try
:
query
=
Tool
.
update
(
with
get_db
()
as
db
:
**
updated
,
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
updated_at
=
int
(
time
.
time
()),
{
**
updated
,
"updated_at"
:
int
(
time
.
time
())}
).
where
(
Tool
.
id
==
id
)
)
query
.
execute
()
db
.
commit
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
tool
=
db
.
query
(
Tool
).
get
(
id
)
return
ToolModel
(
**
model_to_dict
(
tool
))
db
.
refresh
(
tool
)
return
ToolModel
.
model_validate
(
tool
)
except
:
except
:
return
None
return
None
def
delete_tool_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_tool_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
try
:
query
=
Tool
.
delete
().
where
((
Tool
.
id
==
id
))
with
get_db
()
as
db
:
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
return
True
except
:
except
:
return
False
return
False
Tools
=
ToolsTable
(
DB
)
Tools
=
ToolsTable
()
backend/apps/webui/models/users.py
View file @
f9e3c47d
from
pydantic
import
BaseModel
,
ConfigDict
from
pydantic
import
BaseModel
,
ConfigDict
,
parse_obj_as
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
import
time
import
time
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
utils.misc
import
get_gravatar_url
from
utils.misc
import
get_gravatar_url
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
,
Session
,
get_db
from
apps.webui.models.chats
import
Chats
from
apps.webui.models.chats
import
Chats
####################
####################
...
@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
...
@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
####################
####################
class
User
(
Model
):
class
User
(
Base
):
id
=
CharField
(
unique
=
True
)
__tablename__
=
"user"
name
=
CharField
()
email
=
CharField
()
role
=
CharField
()
profile_image_url
=
TextField
()
last_active_at
=
BigIntegerField
()
id
=
Column
(
String
,
primary_key
=
True
)
updated_at
=
BigIntegerField
()
name
=
Column
(
String
)
created_at
=
BigIntegerField
()
email
=
Column
(
String
)
role
=
Column
(
String
)
profile_image_url
=
Column
(
Text
)
api_key
=
CharField
(
null
=
True
,
unique
=
True
)
last_active_at
=
Column
(
BigInteger
)
settings
=
JSONField
(
null
=
True
)
updated_at
=
Column
(
BigInteger
)
info
=
JSONField
(
null
=
True
)
created_at
=
Column
(
BigInteger
)
oauth_sub
=
TextField
(
null
=
True
,
unique
=
True
)
api_key
=
Column
(
String
,
nullable
=
True
,
unique
=
True
)
settings
=
Column
(
JSONField
,
nullable
=
True
)
info
=
Column
(
JSONField
,
nullable
=
True
)
class
Meta
:
oauth_sub
=
Column
(
Text
,
unique
=
True
)
database
=
DB
class
UserSettings
(
BaseModel
):
class
UserSettings
(
BaseModel
):
...
@@ -57,6 +57,8 @@ class UserModel(BaseModel):
...
@@ -57,6 +57,8 @@ class UserModel(BaseModel):
oauth_sub
:
Optional
[
str
]
=
None
oauth_sub
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
####################
# Forms
# Forms
...
@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
...
@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
class
UsersTable
:
class
UsersTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
User
])
def
insert_new_user
(
def
insert_new_user
(
self
,
self
,
...
@@ -89,6 +88,7 @@ class UsersTable:
...
@@ -89,6 +88,7 @@ class UsersTable:
role
:
str
=
"pending"
,
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
)
->
Optional
[
UserModel
]:
with
get_db
()
as
db
:
user
=
UserModel
(
user
=
UserModel
(
**
{
**
{
"id"
:
id
,
"id"
:
id
,
...
@@ -102,7 +102,10 @@ class UsersTable:
...
@@ -102,7 +102,10 @@ class UsersTable:
"oauth_sub"
:
oauth_sub
,
"oauth_sub"
:
oauth_sub
,
}
}
)
)
result
=
User
.
create
(
**
user
.
model_dump
())
result
=
User
(
**
user
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
if
result
:
return
user
return
user
else
:
else
:
...
@@ -110,56 +113,67 @@ class UsersTable:
...
@@ -110,56 +113,67 @@ class UsersTable:
def
get_user_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
with
get_db
()
as
db
:
return
UserModel
(
**
model_to_dict
(
user
))
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
except
:
return
UserModel
.
model_validate
(
user
)
except
Exception
as
e
:
return
None
return
None
def
get_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
try
:
try
:
user
=
User
.
get
(
User
.
api_key
==
api_key
)
with
get_db
()
as
db
:
return
UserModel
(
**
model_to_dict
(
user
))
user
=
db
.
query
(
User
).
filter_by
(
api_key
=
api_key
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
except
:
return
None
return
None
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
try
:
try
:
user
=
User
.
get
(
User
.
email
==
email
)
with
get_db
()
as
db
:
return
UserModel
(
**
model_to_dict
(
user
))
user
=
db
.
query
(
User
).
filter_by
(
email
=
email
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
except
:
return
None
return
None
def
get_user_by_oauth_sub
(
self
,
sub
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_oauth_sub
(
self
,
sub
:
str
)
->
Optional
[
UserModel
]:
try
:
try
:
user
=
User
.
get
(
User
.
oauth_sub
==
sub
)
with
get_db
()
as
db
:
return
UserModel
(
**
model_to_dict
(
user
))
user
=
db
.
query
(
User
).
filter_by
(
oauth_sub
=
sub
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
except
:
return
None
return
None
def
get_users
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
def
get_users
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
return
[
with
get_db
()
as
db
:
UserModel
(
**
model_to_dict
(
user
))
users
=
(
for
user
in
User
.
select
()
db
.
query
(
User
)
# .limit(limit).offset(skip)
# .offset(skip).limit(limit)
]
.
all
()
)
return
[
UserModel
.
model_validate
(
user
)
for
user
in
users
]
def
get_num_users
(
self
)
->
Optional
[
int
]:
def
get_num_users
(
self
)
->
Optional
[
int
]:
return
User
.
select
().
count
()
with
get_db
()
as
db
:
return
db
.
query
(
User
).
count
()
def
get_first_user
(
self
)
->
UserModel
:
def
get_first_user
(
self
)
->
UserModel
:
try
:
try
:
user
=
User
.
select
().
order_by
(
User
.
created_at
).
first
()
with
get_db
()
as
db
:
return
UserModel
(
**
model_to_dict
(
user
))
user
=
db
.
query
(
User
).
order_by
(
User
.
created_at
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
except
:
return
None
return
None
def
update_user_role_by_id
(
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
def
update_user_role_by_id
(
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
try
:
try
:
query
=
User
.
update
(
role
=
role
).
where
(
User
.
id
==
id
)
with
get_db
()
as
db
:
query
.
execute
(
)
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"role"
:
role
}
)
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
return
UserModel
.
model_
validate
(
user
)
except
:
except
:
return
None
return
None
...
@@ -167,23 +181,28 @@ class UsersTable:
...
@@ -167,23 +181,28 @@ class UsersTable:
self
,
id
:
str
,
profile_image_url
:
str
self
,
id
:
str
,
profile_image_url
:
str
)
->
Optional
[
UserModel
]:
)
->
Optional
[
UserModel
]:
try
:
try
:
query
=
User
.
update
(
profile_image_url
=
profile_image_url
).
where
(
with
get_db
()
as
db
:
User
.
id
==
id
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"profile_image_url"
:
profile_image_url
}
)
)
query
.
execute
()
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
return
UserModel
.
model_
validate
(
user
)
except
:
except
:
return
None
return
None
def
update_user_last_active_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
def
update_user_last_active_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
try
:
query
=
User
.
update
(
last_active_at
=
int
(
time
.
time
())).
where
(
User
.
id
==
id
)
with
get_db
()
as
db
:
query
.
execute
()
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"last_active_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
return
UserModel
.
model_
validate
(
user
)
except
:
except
:
return
None
return
None
...
@@ -191,22 +210,25 @@ class UsersTable:
...
@@ -191,22 +210,25 @@ class UsersTable:
self
,
id
:
str
,
oauth_sub
:
str
self
,
id
:
str
,
oauth_sub
:
str
)
->
Optional
[
UserModel
]:
)
->
Optional
[
UserModel
]:
try
:
try
:
query
=
User
.
update
(
oauth_sub
=
oauth_sub
).
where
(
User
.
id
==
id
)
with
get_db
()
as
db
:
query
.
execute
()
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"oauth_sub"
:
oauth_sub
})
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
return
UserModel
.
model_
validate
(
user
)
except
:
except
:
return
None
return
None
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
try
:
try
:
query
=
User
.
update
(
**
updated
).
where
(
User
.
id
==
id
)
with
get_db
()
as
db
:
query
.
execute
()
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
updated
)
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
except
:
return
UserModel
.
model_validate
(
user
)
# return UserModel(**user.dict())
except
Exception
as
e
:
return
None
return
None
def
delete_user_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_user_by_id
(
self
,
id
:
str
)
->
bool
:
...
@@ -215,9 +237,10 @@ class UsersTable:
...
@@ -215,9 +237,10 @@ class UsersTable:
result
=
Chats
.
delete_chats_by_user_id
(
id
)
result
=
Chats
.
delete_chats_by_user_id
(
id
)
if
result
:
if
result
:
with
get_db
()
as
db
:
# Delete User
# Delete User
query
=
User
.
delete
().
where
(
User
.
id
==
id
)
db
.
query
(
User
).
filter_by
(
id
=
id
).
delete
(
)
query
.
execute
()
# Remove the rows, return number of rows removed.
db
.
commit
()
return
True
return
True
else
:
else
:
...
@@ -227,19 +250,20 @@ class UsersTable:
...
@@ -227,19 +250,20 @@ class UsersTable:
def
update_user_api_key_by_id
(
self
,
id
:
str
,
api_key
:
str
)
->
str
:
def
update_user_api_key_by_id
(
self
,
id
:
str
,
api_key
:
str
)
->
str
:
try
:
try
:
query
=
User
.
update
(
api_key
=
api_key
).
where
(
User
.
id
==
id
)
with
get_db
()
as
db
:
result
=
query
.
execute
(
)
result
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"api_key"
:
api_key
}
)
db
.
commit
()
return
True
if
result
==
1
else
False
return
True
if
result
==
1
else
False
except
:
except
:
return
False
return
False
def
get_user_api_key_by_id
(
self
,
id
:
str
)
->
Optional
[
str
]:
def
get_user_api_key_by_id
(
self
,
id
:
str
)
->
Optional
[
str
]:
try
:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
user
.
api_key
return
user
.
api_key
except
:
except
Exception
as
e
:
return
None
return
None
Users
=
UsersTable
(
DB
)
Users
=
UsersTable
()
backend/apps/webui/routers/chats.py
View file @
f9e3c47d
...
@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
...
@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@
router
.
get
(
"/list/user/{user_id}"
,
response_model
=
List
[
ChatTitleIdResponse
])
@
router
.
get
(
"/list/user/{user_id}"
,
response_model
=
List
[
ChatTitleIdResponse
])
async
def
get_user_chat_list_by_user_id
(
async
def
get_user_chat_list_by_user_id
(
user_id
:
str
,
user
=
Depends
(
get_admin_user
),
skip
:
int
=
0
,
limit
:
int
=
50
user_id
:
str
,
user
=
Depends
(
get_admin_user
),
skip
:
int
=
0
,
limit
:
int
=
50
,
):
):
return
Chats
.
get_chat_list_by_user_id
(
return
Chats
.
get_chat_list_by_user_id
(
user_id
,
include_archived
=
True
,
skip
=
skip
,
limit
=
limit
user_id
,
include_archived
=
True
,
skip
=
skip
,
limit
=
limit
...
@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
...
@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
@
router
.
get
(
"/all/archived"
,
response_model
=
List
[
ChatResponse
])
@
router
.
get
(
"/all/archived"
,
response_model
=
List
[
ChatResponse
])
async
def
get_user_chats
(
user
=
Depends
(
get_verified_user
)):
async
def
get_user_
archived_
chats
(
user
=
Depends
(
get_verified_user
)):
return
[
return
[
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
for
chat
in
Chats
.
get_archived_chats_by_user_id
(
user
.
id
)
for
chat
in
Chats
.
get_archived_chats_by_user_id
(
user
.
id
)
...
@@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name(
...
@@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name(
form_data
:
TagNameForm
,
user
=
Depends
(
get_verified_user
)
form_data
:
TagNameForm
,
user
=
Depends
(
get_verified_user
)
):
):
print
(
form_data
)
chat_ids
=
[
chat_ids
=
[
chat_id_tag
.
chat_id
chat_id_tag
.
chat_id
for
chat_id_tag
in
Tags
.
get_chat_ids_by_tag_name_and_user_id
(
for
chat_id_tag
in
Tags
.
get_chat_ids_by_tag_name_and_user_id
(
...
...
backend/apps/webui/routers/documents.py
View file @
f9e3c47d
...
@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
...
@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
@
router
.
post
(
"/doc/update"
,
response_model
=
Optional
[
DocumentResponse
])
@
router
.
post
(
"/doc/update"
,
response_model
=
Optional
[
DocumentResponse
])
async
def
update_doc_by_name
(
async
def
update_doc_by_name
(
name
:
str
,
form_data
:
DocumentUpdateForm
,
user
=
Depends
(
get_admin_user
)
name
:
str
,
form_data
:
DocumentUpdateForm
,
user
=
Depends
(
get_admin_user
),
):
):
doc
=
Documents
.
update_doc_by_name
(
name
,
form_data
)
doc
=
Documents
.
update_doc_by_name
(
name
,
form_data
)
if
doc
:
if
doc
:
...
...
Prev
1
2
3
4
5
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment