Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
d0e89a03
Unverified
Commit
d0e89a03
authored
Jul 02, 2024
by
Timothy Jaeryang Baek
Committed by
GitHub
Jul 02, 2024
Browse files
Merge pull request #3327 from jonathan-rohde/feat/sqlalchemy-instead-of-peewee
BREAKING CHANGE/sqlalchemy instead of peewee
parents
2c061777
2aecd7d0
Changes
60
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
577 additions
and
714 deletions
+577
-714
backend/apps/webui/internal/migrations/README.md
backend/apps/webui/internal/migrations/README.md
+0
-21
backend/apps/webui/internal/wrappers.py
backend/apps/webui/internal/wrappers.py
+0
-72
backend/apps/webui/main.py
backend/apps/webui/main.py
+1
-1
backend/apps/webui/models/auths.py
backend/apps/webui/models/auths.py
+22
-27
backend/apps/webui/models/chats.py
backend/apps/webui/models/chats.py
+109
-142
backend/apps/webui/models/documents.py
backend/apps/webui/models/documents.py
+42
-48
backend/apps/webui/models/files.py
backend/apps/webui/models/files.py
+25
-27
backend/apps/webui/models/functions.py
backend/apps/webui/models/functions.py
+60
-62
backend/apps/webui/models/memories.py
backend/apps/webui/models/memories.py
+34
-40
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+32
-36
backend/apps/webui/models/prompts.py
backend/apps/webui/models/prompts.py
+30
-41
backend/apps/webui/models/tags.py
backend/apps/webui/models/tags.py
+85
-70
backend/apps/webui/models/tools.py
backend/apps/webui/models/tools.py
+40
-46
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+74
-70
backend/apps/webui/routers/chats.py
backend/apps/webui/routers/chats.py
+5
-2
backend/apps/webui/routers/documents.py
backend/apps/webui/routers/documents.py
+3
-1
backend/apps/webui/routers/files.py
backend/apps/webui/routers/files.py
+1
-4
backend/apps/webui/routers/memories.py
backend/apps/webui/routers/memories.py
+3
-1
backend/apps/webui/routers/models.py
backend/apps/webui/routers/models.py
+8
-2
backend/apps/webui/routers/prompts.py
backend/apps/webui/routers/prompts.py
+3
-1
No files found.
backend/apps/webui/internal/migrations/README.md
deleted
100644 → 0
View file @
2c061777
# Database Migrations
This directory contains all the database migrations for the web app.
Migrations are done using the
[
`peewee-migrate`
](
https://github.com/klen/peewee_migrate
)
library.
Migrations are automatically ran at app startup.
## Creating a migration
Have you made a change to the schema of an existing model?
You will need to create a migration file to ensure that existing databases are updated for backwards compatibility.
1.
Have a database file (
`webui.db`
) that has the old schema prior to any of your changes.
2.
Make your changes to the models.
3.
From the
`backend`
directory, run the following command:
```
bash
pw_migrate create
--auto
--auto-source
apps.webui.models
--database
sqlite:///
${
SQLITE_DB
}
--directory
apps/web/internal/migrations
${
MIGRATION_NAME
}
```
-
`$SQLITE_DB`
should be the path to the database file.
-
`$MIGRATION_NAME`
should be a descriptive name for the migration.
4.
The migration file will be created in the
`apps/web/internal/migrations`
directory.
backend/apps/webui/internal/wrappers.py
deleted
100644 → 0
View file @
2c061777
from
contextvars
import
ContextVar
from
peewee
import
*
from
peewee
import
PostgresqlDatabase
,
InterfaceError
as
PeeWeeInterfaceError
import
logging
from
playhouse.db_url
import
connect
,
parse
from
playhouse.shortcuts
import
ReconnectMixin
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
db_state_default
=
{
"closed"
:
None
,
"conn"
:
None
,
"ctx"
:
None
,
"transactions"
:
None
}
db_state
=
ContextVar
(
"db_state"
,
default
=
db_state_default
.
copy
())
class
PeeweeConnectionState
(
object
):
def
__init__
(
self
,
**
kwargs
):
super
().
__setattr__
(
"_state"
,
db_state
)
super
().
__init__
(
**
kwargs
)
def
__setattr__
(
self
,
name
,
value
):
self
.
_state
.
get
()[
name
]
=
value
def
__getattr__
(
self
,
name
):
value
=
self
.
_state
.
get
()[
name
]
return
value
class
CustomReconnectMixin
(
ReconnectMixin
):
reconnect_errors
=
(
# psycopg2
(
OperationalError
,
"termin"
),
(
InterfaceError
,
"closed"
),
# peewee
(
PeeWeeInterfaceError
,
"closed"
),
)
class
ReconnectingPostgresqlDatabase
(
CustomReconnectMixin
,
PostgresqlDatabase
):
pass
def
register_connection
(
db_url
):
db
=
connect
(
db_url
)
if
isinstance
(
db
,
PostgresqlDatabase
):
# Enable autoconnect for SQLite databases, managed by Peewee
db
.
autoconnect
=
True
db
.
reuse_if_open
=
True
log
.
info
(
"Connected to PostgreSQL database"
)
# Get the connection details
connection
=
parse
(
db_url
)
# Use our custom database class that supports reconnection
db
=
ReconnectingPostgresqlDatabase
(
connection
[
"database"
],
user
=
connection
[
"user"
],
password
=
connection
[
"password"
],
host
=
connection
[
"host"
],
port
=
connection
[
"port"
],
)
db
.
connect
(
reuse_if_open
=
True
)
elif
isinstance
(
db
,
SqliteDatabase
):
# Enable autoconnect for SQLite databases, managed by Peewee
db
.
autoconnect
=
True
db
.
reuse_if_open
=
True
log
.
info
(
"Connected to SQLite database"
)
else
:
raise
ValueError
(
"Unsupported database connection"
)
return
db
backend/apps/webui/main.py
View file @
d0e89a03
...
...
@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
from
fastapi.responses
import
StreamingResponse
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
from
sqlalchemy.orm
import
Session
from
apps.webui.routers
import
(
auths
,
users
,
...
...
backend/apps/webui/models/auths.py
View file @
d0e89a03
from
pydantic
import
BaseModel
from
typing
import
List
,
Union
,
Optional
import
time
from
typing
import
Optional
import
uuid
import
logging
from
peewee
import
*
from
sqlalchemy
import
String
,
Column
,
Boolean
,
Text
from
apps.webui.models.users
import
UserModel
,
Users
from
utils.utils
import
verify_password
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
Session
from
config
import
SRC_LOG_LEVELS
...
...
@@ -20,14 +19,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Auth
(
Model
):
id
=
CharField
(
unique
=
True
)
email
=
CharField
()
password
=
TextField
()
active
=
BooleanField
()
class
Auth
(
Base
):
__tablename__
=
"auth"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
email
=
Column
(
String
)
password
=
Column
(
Text
)
active
=
Column
(
Boolean
)
class
AuthModel
(
BaseModel
):
...
...
@@ -94,9 +92,6 @@ class AddUserForm(SignupForm):
class
AuthsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Auth
])
def
insert_new_auth
(
self
,
...
...
@@ -114,12 +109,16 @@ class AuthsTable:
auth
=
AuthModel
(
**
{
"id"
:
id
,
"email"
:
email
,
"password"
:
password
,
"active"
:
True
}
)
result
=
Auth
.
create
(
**
auth
.
model_dump
())
result
=
Auth
(
**
auth
.
model_dump
())
Session
.
add
(
result
)
user
=
Users
.
insert_new_user
(
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
and
user
:
return
user
else
:
...
...
@@ -128,7 +127,7 @@ class AuthsTable:
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user:
{
email
}
"
)
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
auth
=
Session
.
query
(
Auth
).
filter_by
(
email
=
email
,
active
=
True
)
.
first
()
if
auth
:
if
verify_password
(
password
,
auth
.
password
):
user
=
Users
.
get_user_by_id
(
auth
.
id
)
...
...
@@ -155,7 +154,7 @@ class AuthsTable:
def
authenticate_user_by_trusted_header
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user_by_trusted_header:
{
email
}
"
)
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
auth
=
Session
.
query
(
Auth
).
filter
(
email
=
email
,
active
=
True
)
.
first
()
if
auth
:
user
=
Users
.
get_user_by_id
(
auth
.
id
)
return
user
...
...
@@ -164,18 +163,16 @@ class AuthsTable:
def
update_user_password_by_id
(
self
,
id
:
str
,
new_password
:
str
)
->
bool
:
try
:
query
=
Auth
.
update
(
password
=
new_password
).
where
(
Auth
.
id
==
id
)
result
=
query
.
execute
(
)
result
=
(
Session
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"password"
:
new_password
}
)
)
return
True
if
result
==
1
else
False
except
:
return
False
def
update_email_by_id
(
self
,
id
:
str
,
email
:
str
)
->
bool
:
try
:
query
=
Auth
.
update
(
email
=
email
).
where
(
Auth
.
id
==
id
)
result
=
query
.
execute
()
result
=
Session
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"email"
:
email
})
return
True
if
result
==
1
else
False
except
:
return
False
...
...
@@ -186,9 +183,7 @@ class AuthsTable:
result
=
Users
.
delete_user_by_id
(
id
)
if
result
:
# Delete Auth
query
=
Auth
.
delete
().
where
(
Auth
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Auth
).
filter_by
(
id
=
id
).
delete
()
return
True
else
:
...
...
@@ -197,4 +192,4 @@ class AuthsTable:
return
False
Auths
=
AuthsTable
(
DB
)
Auths
=
AuthsTable
()
backend/apps/webui/models/chats.py
View file @
d0e89a03
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
import
json
import
uuid
import
time
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Boolean
,
Text
from
apps.webui.internal.db
import
Base
,
Session
####################
# Chat DB Schema
####################
class
Chat
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
title
=
TextField
()
chat
=
TextField
()
# Save Chat JSON as Text
class
Chat
(
Base
):
__tablename__
=
"chat"
created_at
=
BigIntegerField
()
updated_at
=
BigIntegerField
()
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
title
=
Column
(
Text
)
chat
=
Column
(
Text
)
# Save Chat JSON as Text
share_id
=
CharField
(
null
=
True
,
unique
=
True
)
archived
=
BooleanField
(
default
=
False
)
created_at
=
Column
(
BigInteger
)
updated_at
=
Column
(
BigInteger
)
class
Meta
:
database
=
DB
share_id
=
Column
(
Text
,
unique
=
True
,
nullable
=
True
)
archived
=
Column
(
Boolean
,
default
=
False
)
class
ChatModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
id
:
str
user_id
:
str
title
:
str
...
...
@@ -75,9 +77,6 @@ class ChatTitleIdResponse(BaseModel):
class
ChatTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
db
.
create_tables
([
Chat
])
def
insert_new_chat
(
self
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
id
=
str
(
uuid
.
uuid4
())
...
...
@@ -94,26 +93,28 @@ class ChatTable:
}
)
result
=
Chat
.
create
(
**
chat
.
model_dump
())
return
chat
if
result
else
None
result
=
Chat
(
**
chat
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
return
ChatModel
.
model_validate
(
result
)
if
result
else
None
def
update_chat_by_id
(
self
,
id
:
str
,
chat
:
dict
)
->
Optional
[
ChatModel
]:
try
:
query
=
Chat
.
update
(
chat
=
json
.
dumps
(
chat
),
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
,
updated_at
=
int
(
time
.
time
()),
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
except
:
chat_obj
=
Session
.
get
(
Chat
,
id
)
chat_obj
.
chat
=
json
.
dumps
(
chat
)
chat_obj
.
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
chat_obj
.
updated_at
=
int
(
time
.
time
())
Session
.
commit
()
Session
.
refresh
(
chat_obj
)
return
ChatModel
.
model_validate
(
chat_obj
)
except
Exception
as
e
:
return
None
def
insert_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
# Get the existing chat to share
chat
=
Chat
.
get
(
Chat
.
id
==
chat_id
)
chat
=
Session
.
get
(
Chat
,
chat_id
)
# Check if the chat is already shared
if
chat
.
share_id
:
return
self
.
get_chat_by_id_and_user_id
(
chat
.
share_id
,
"shared"
)
...
...
@@ -128,10 +129,15 @@ class ChatTable:
"updated_at"
:
int
(
time
.
time
()),
}
)
shared_result
=
Chat
.
create
(
**
shared_chat
.
model_dump
())
shared_result
=
Chat
(
**
shared_chat
.
model_dump
())
Session
.
add
(
shared_result
)
Session
.
commit
()
Session
.
refresh
(
shared_result
)
# Update the original chat with the share_id
result
=
(
Chat
.
update
(
share_id
=
shared_chat
.
id
).
where
(
Chat
.
id
==
chat_id
).
execute
()
Session
.
query
(
Chat
)
.
filter_by
(
id
=
chat_id
)
.
update
({
"share_id"
:
shared_chat
.
id
})
)
return
shared_chat
if
(
shared_result
and
result
)
else
None
...
...
@@ -139,26 +145,20 @@ class ChatTable:
def
update_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
print
(
"update_shared_chat_by_id"
)
chat
=
Chat
.
get
(
Chat
.
id
==
chat_id
)
chat
=
Session
.
get
(
Chat
,
chat_id
)
print
(
chat
)
chat
.
title
=
chat
.
title
chat
.
chat
=
chat
.
chat
Session
.
commit
()
Session
.
refresh
(
chat
)
query
=
Chat
.
update
(
title
=
chat
.
title
,
chat
=
chat
.
chat
,
).
where
(
Chat
.
id
==
chat
.
share_id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
chat
.
share_id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
return
self
.
get_chat_by_id
(
chat
.
share_id
)
except
:
return
None
def
delete_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
bool
:
try
:
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
==
f
"shared-
{
chat_id
}
"
)
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Chat
).
filter_by
(
user_id
=
f
"shared-
{
chat_id
}
"
).
delete
()
return
True
except
:
return
False
...
...
@@ -167,40 +167,27 @@ class ChatTable:
self
,
id
:
str
,
share_id
:
Optional
[
str
]
)
->
Optional
[
ChatModel
]:
try
:
query
=
Chat
.
update
(
share_id
=
share_id
,
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
chat
=
Session
.
get
(
Chat
,
id
)
chat
.
share_id
=
share_id
Session
.
commit
()
Session
.
refresh
(
chat
)
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
toggle_chat_archive_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
self
.
get_chat_by_id
(
id
)
query
=
Chat
.
update
(
archived
=
(
not
chat
.
archived
),
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
chat
=
Session
.
get
(
Chat
,
id
)
chat
.
archived
=
not
chat
.
archived
Session
.
commit
()
Session
.
refresh
(
chat
)
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
archive_all_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
chats
=
self
.
get_chats_by_user_id
(
user_id
)
for
chat
in
chats
:
query
=
Chat
.
update
(
archived
=
True
,
).
where
(
Chat
.
id
==
chat
.
id
)
query
.
execute
()
Session
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
update
({
"archived"
:
True
})
return
True
except
:
return
False
...
...
@@ -208,15 +195,14 @@ class ChatTable:
def
get_archived_chat_list_by_user_id
(
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
True
)
.
where
(
Chat
.
user_id
==
user_id
)
all_chats
=
(
Session
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
# .limit(limit).offset(skip)
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_list_by_user_id
(
self
,
...
...
@@ -225,92 +211,80 @@ class ChatTable:
skip
:
int
=
0
,
limit
:
int
=
50
,
)
->
List
[
ChatModel
]:
if
include_archived
:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
else
:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
False
)
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit)
# .offset(skip)
]
query
=
Session
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
)
if
not
include_archived
:
query
=
query
.
filter_by
(
archived
=
False
)
all_chats
=
(
query
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit).offset(skip)
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_list_by_chat_ids
(
self
,
chat_ids
:
List
[
str
],
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
False
)
.
where
(
Chat
.
id
.
in_
(
chat_ids
))
all_chats
=
(
Session
.
query
(
Chat
)
.
filter
(
Chat
.
id
.
in_
(
chat_ids
))
.
filter_by
(
archived
=
False
)
.
order_by
(
Chat
.
updated_at
.
desc
())
]
.
all
()
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_
to_dict
(
chat
)
)
chat
=
Session
.
get
(
Chat
,
id
)
return
ChatModel
.
model_
validate
(
chat
)
except
:
return
None
def
get_chat_by_share_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
Chat
.
get
(
Chat
.
share_id
==
id
)
chat
=
Session
.
query
(
Chat
).
filter_by
(
share_id
=
id
).
first
(
)
if
chat
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
return
self
.
get_chat_by_id
(
id
)
else
:
return
None
except
:
except
Exception
as
e
:
return
None
def
get_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
,
Chat
.
user_id
==
user_id
)
return
ChatModel
(
**
model_
to_dict
(
chat
)
)
chat
=
Session
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
)
.
first
()
return
ChatModel
.
model_
validate
(
chat
)
except
:
return
None
def
get_chats
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
().
order_by
(
Chat
.
updated_at
.
desc
())
all_chats
=
(
Session
.
query
(
Chat
)
# .limit(limit).offset(skip)
]
.
order_by
(
Chat
.
updated_at
.
desc
())
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
user_id
==
user_id
)
all_chats
=
(
Session
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit).offset(skip
)
]
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_archived_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
return
[
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
()
.
where
(
Chat
.
archived
==
True
)
.
where
(
Chat
.
user_id
==
user_id
)
all_chats
=
(
Session
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
]
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
delete_chat_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Chat
).
filter_by
(
id
=
id
).
delete
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
...
...
@@ -318,8 +292,7 @@ class ChatTable:
def
delete_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
)
&
(
Chat
.
user_id
==
user_id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
...
...
@@ -327,29 +300,23 @@ class ChatTable:
def
delete_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
self
.
delete_shared_chats_by_user_id
(
user_id
)
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
==
user_id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
def
delete_shared_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
Chat
.
select
().
where
(
Chat
.
user_id
==
user_id
)
]
chats_by_user
=
Session
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
all
()
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
chats_by_user
]
query
=
Chat
.
delete
().
where
(
Chat
.
user_id
<<
shared_chat_ids
)
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Chat
).
filter
(
Chat
.
user_id
.
in_
(
shared_chat_ids
)).
delete
()
return
True
except
:
return
False
Chats
=
ChatTable
(
DB
)
Chats
=
ChatTable
()
backend/apps/webui/models/documents.py
View file @
d0e89a03
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
import
logging
from
utils.utils
import
decode_token
from
utils.misc
import
get_gravatar_url
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
Session
import
json
...
...
@@ -22,20 +19,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Document
(
Model
):
collection_name
=
CharField
(
unique
=
True
)
name
=
CharField
(
unique
=
True
)
title
=
TextField
()
filename
=
TextField
()
content
=
TextField
(
null
=
True
)
user_id
=
CharField
()
timestamp
=
BigIntegerField
()
class
Document
(
Base
):
__tablename__
=
"document"
class
Meta
:
database
=
DB
collection_name
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
,
unique
=
True
)
title
=
Column
(
Text
)
filename
=
Column
(
Text
)
content
=
Column
(
Text
,
nullable
=
True
)
user_id
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
DocumentModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
collection_name
:
str
name
:
str
title
:
str
...
...
@@ -72,9 +70,6 @@ class DocumentForm(DocumentUpdateForm):
class
DocumentsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Document
])
def
insert_new_doc
(
self
,
user_id
:
str
,
form_data
:
DocumentForm
...
...
@@ -88,9 +83,12 @@ class DocumentsTable:
)
try
:
result
=
Document
.
create
(
**
document
.
model_dump
())
result
=
Document
(
**
document
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
d
ocument
return
D
ocument
Model
.
model_validate
(
result
)
else
:
return
None
except
:
...
...
@@ -98,31 +96,29 @@ class DocumentsTable:
def
get_doc_by_name
(
self
,
name
:
str
)
->
Optional
[
DocumentModel
]:
try
:
document
=
Document
.
get
(
Document
.
name
==
name
)
return
DocumentModel
(
**
model_
to_dict
(
document
)
)
document
=
Session
.
query
(
Document
).
filter_by
(
name
=
name
).
first
(
)
return
DocumentModel
.
model_
validate
(
document
)
if
document
else
None
except
:
return
None
def
get_docs
(
self
)
->
List
[
DocumentModel
]:
return
[
DocumentModel
(
**
model_to_dict
(
doc
))
for
doc
in
Document
.
select
()
# .limit(limit).offset(skip)
DocumentModel
.
model_validate
(
doc
)
for
doc
in
Session
.
query
(
Document
).
all
()
]
def
update_doc_by_name
(
self
,
name
:
str
,
form_data
:
DocumentUpdateForm
)
->
Optional
[
DocumentModel
]:
try
:
query
=
Document
.
update
(
title
=
form_data
.
title
,
name
=
form_data
.
nam
e
,
timestamp
=
int
(
time
.
time
())
,
).
where
(
Document
.
name
==
name
)
query
.
execute
()
doc
=
Document
.
get
(
Document
.
name
==
form_data
.
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
)
)
Session
.
query
(
Document
).
filter_by
(
name
=
name
)
.
update
(
{
"title"
:
form_data
.
titl
e
,
"name"
:
form_data
.
name
,
"timestamp"
:
int
(
time
.
time
()),
}
)
Session
.
commit
(
)
return
self
.
get_doc_by_name
(
form_data
.
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
return
None
...
...
@@ -135,26 +131,24 @@ class DocumentsTable:
doc_content
=
json
.
loads
(
doc
.
content
if
doc
.
content
else
"{}"
)
doc_content
=
{
**
doc_content
,
**
updated
}
query
=
Document
.
update
(
content
=
json
.
dumps
(
doc_content
),
timestamp
=
int
(
time
.
time
()
),
).
where
(
Document
.
name
==
name
)
query
.
execute
()
doc
=
Document
.
get
(
Document
.
name
==
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
)
)
Session
.
query
(
Document
).
filter_by
(
name
=
name
)
.
update
(
{
"content"
:
json
.
dumps
(
doc_content
),
"timestamp"
:
int
(
time
.
time
()),
}
)
Session
.
commit
(
)
return
self
.
get_doc_by_name
(
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
return
None
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
try
:
query
=
Document
.
delete
().
where
((
Document
.
name
==
name
))
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Document
).
filter_by
(
name
=
name
).
delete
()
return
True
except
:
return
False
Documents
=
DocumentsTable
(
DB
)
Documents
=
DocumentsTable
()
backend/apps/webui/models/files.py
View file @
d0e89a03
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Text
from
apps.webui.internal.db
import
JSONField
,
Base
,
Session
import
json
...
...
@@ -18,15 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
File
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
filename
=
TextField
()
meta
=
JSONField
()
created_at
=
BigIntegerField
()
class
File
(
Base
):
__tablename__
=
"file"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
filename
=
Column
(
Text
)
meta
=
Column
(
JSONField
)
created_at
=
Column
(
BigInteger
)
class
FileModel
(
BaseModel
):
...
...
@@ -36,6 +36,8 @@ class FileModel(BaseModel):
meta
:
dict
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -57,9 +59,6 @@ class FileForm(BaseModel):
class
FilesTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
File
])
def
insert_new_file
(
self
,
user_id
:
str
,
form_data
:
FileForm
)
->
Optional
[
FileModel
]:
file
=
FileModel
(
...
...
@@ -71,9 +70,12 @@ class FilesTable:
)
try
:
result
=
File
.
create
(
**
file
.
model_dump
())
result
=
File
(
**
file
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
f
ile
return
F
ile
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -82,31 +84,27 @@ class FilesTable:
def
get_file_by_id
(
self
,
id
:
str
)
->
Optional
[
FileModel
]:
try
:
file
=
File
.
get
(
File
.
id
==
id
)
return
FileModel
(
**
model_
to_dict
(
file
)
)
file
=
Session
.
get
(
File
,
id
)
return
FileModel
.
model_
validate
(
file
)
except
:
return
None
def
get_files
(
self
)
->
List
[
FileModel
]:
return
[
FileModel
(
**
model_
to_dict
(
file
)
)
for
file
in
File
.
select
()]
return
[
FileModel
.
model_
validate
(
file
)
for
file
in
Session
.
query
(
File
).
all
()]
def
delete_file_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
File
.
delete
().
where
((
File
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
File
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
def
delete_all_files
(
self
)
->
bool
:
try
:
query
=
File
.
delete
()
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
File
).
delete
()
return
True
except
:
return
False
Files
=
FilesTable
(
DB
)
Files
=
FilesTable
()
backend/apps/webui/models/functions.py
View file @
d0e89a03
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
Column
,
String
,
Text
,
BigInteger
,
Boolean
from
apps.webui.internal.db
import
JSONField
,
Base
,
Session
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -21,21 +22,20 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Function
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
name
=
TextField
()
type
=
TextField
()
content
=
TextField
()
meta
=
JSONField
()
valves
=
JSONField
()
is_active
=
BooleanField
(
default
=
False
)
is_global
=
BooleanField
(
default
=
False
)
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Function
(
Base
):
__tablename__
=
"function"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
name
=
Column
(
Text
)
type
=
Column
(
Text
)
content
=
Column
(
Text
)
meta
=
Column
(
JSONField
)
valves
=
Column
(
JSONField
)
is_active
=
Column
(
Boolean
)
is_global
=
Column
(
Boolean
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
FunctionMeta
(
BaseModel
):
...
...
@@ -55,6 +55,8 @@ class FunctionModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -85,9 +87,6 @@ class FunctionValves(BaseModel):
class
FunctionsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Function
])
def
insert_new_function
(
self
,
user_id
:
str
,
type
:
str
,
form_data
:
FunctionForm
...
...
@@ -103,9 +102,12 @@ class FunctionsTable:
)
try
:
result
=
Function
.
create
(
**
function
.
model_dump
())
result
=
Function
(
**
function
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
f
unction
return
F
unction
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -114,21 +116,21 @@ class FunctionsTable:
def
get_function_by_id
(
self
,
id
:
str
)
->
Optional
[
FunctionModel
]:
try
:
function
=
Funct
ion
.
get
(
Function
.
id
==
id
)
return
FunctionModel
(
**
model_
to_dict
(
function
)
)
function
=
Sess
ion
.
get
(
Function
,
id
)
return
FunctionModel
.
model_
validate
(
function
)
except
:
return
None
def
get_functions
(
self
,
active_only
=
False
)
->
List
[
FunctionModel
]:
if
active_only
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
().
wh
er
e
(
Function
.
is_active
==
True
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
Session
.
qu
er
y
(
Function
).
filter_by
(
is_active
=
True
)
.
all
()
]
else
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
()
FunctionModel
.
model_
validate
(
function
)
for
function
in
Session
.
query
(
Function
).
all
()
]
def
get_functions_by_type
(
...
...
@@ -136,15 +138,15 @@ class FunctionsTable:
)
->
List
[
FunctionModel
]:
if
active_only
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
().
where
(
Function
.
type
==
type
,
Function
.
is_active
==
True
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
Session
.
query
(
Function
)
.
filter_by
(
type
=
type
,
is_active
=
True
)
.
all
(
)
]
else
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
().
wh
er
e
(
Function
.
type
==
type
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
Session
.
qu
er
y
(
Function
).
filter_by
(
type
=
type
).
all
(
)
]
def
get_global_filter_functions
(
self
)
->
List
[
FunctionModel
]:
...
...
@@ -159,7 +161,7 @@ class FunctionsTable:
def
get_function_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
function
=
Funct
ion
.
get
(
Function
.
id
==
id
)
function
=
Sess
ion
.
get
(
Function
,
id
)
return
function
.
valves
if
function
.
valves
else
{}
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
...
...
@@ -169,14 +171,12 @@ class FunctionsTable:
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
FunctionValves
]:
try
:
query
=
Function
.
update
(
**
{
"valves"
:
valves
},
updated_at
=
int
(
time
.
time
()),
).
where
(
Function
.
id
==
id
)
query
.
execute
()
function
=
Function
.
get
(
Function
.
id
==
id
)
return
FunctionValves
(
**
model_to_dict
(
function
))
function
=
Session
.
get
(
Function
,
id
)
function
.
valves
=
valves
function
.
updated_at
=
int
(
time
.
time
())
Session
.
commit
()
Session
.
refresh
(
function
)
return
self
.
get_function_by_id
(
id
)
except
:
return
None
...
...
@@ -223,38 +223,36 @@ class FunctionsTable:
def
update_function_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
FunctionModel
]:
try
:
query
=
Function
.
update
(
**
updated
,
updated
_at
=
int
(
time
.
time
())
,
).
where
(
Function
.
id
==
id
)
query
.
execute
()
function
=
Function
.
get
(
Function
.
id
==
id
)
return
FunctionModel
(
**
model_to_dict
(
function
)
)
Session
.
query
(
Function
).
filter_by
(
id
=
id
)
.
update
(
{
**
updated
,
"updated_at"
:
int
(
time
.
time
()),
}
)
Session
.
commit
(
)
return
self
.
get_function_by_id
(
id
)
except
:
return
None
def
deactivate_all_functions
(
self
)
->
Optional
[
bool
]:
try
:
query
=
Function
.
update
(
**
{
"is_active"
:
False
},
updated_at
=
int
(
time
.
time
()),
Session
.
query
(
Function
).
update
(
{
"is_active"
:
False
,
"updated_at"
:
int
(
time
.
time
()),
}
)
query
.
execute
()
Session
.
commit
()
return
True
except
:
return
None
def
delete_function_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Function
.
delete
().
where
((
Function
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Function
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
Functions
=
FunctionsTable
(
DB
)
Functions
=
FunctionsTable
()
backend/apps/webui/models/memories.py
View file @
d0e89a03
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Union
,
Optional
from
apps.webui.internal.db
import
DB
from
apps.webui.models.chats
import
Chats
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
Session
import
time
import
uuid
...
...
@@ -14,15 +13,14 @@ import uuid
####################
class
Memory
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
content
=
TextField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Memory
(
Base
):
__tablename__
=
"memory"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
content
=
Column
(
Text
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
MemoryModel
(
BaseModel
):
...
...
@@ -32,6 +30,8 @@ class MemoryModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -39,9 +39,6 @@ class MemoryModel(BaseModel):
class
MemoriesTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Memory
])
def
insert_new_memory
(
self
,
...
...
@@ -59,9 +56,12 @@ class MemoriesTable:
"updated_at"
:
int
(
time
.
time
()),
}
)
result
=
Memory
.
create
(
**
memory
.
model_dump
())
result
=
Memory
(
**
memory
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
m
emory
return
M
emory
Model
.
model_validate
(
result
)
else
:
return
None
...
...
@@ -71,40 +71,38 @@ class MemoriesTable:
content
:
str
,
)
->
Optional
[
MemoryModel
]:
try
:
memory
=
Memory
.
get
(
Memory
.
id
==
id
)
memory
.
content
=
content
memory
.
updated_at
=
int
(
time
.
time
()
)
memory
.
save
()
return
MemoryModel
(
**
model_to_dict
(
memory
)
)
Session
.
query
(
Memory
).
filter_by
(
id
=
id
).
update
(
{
"
content
"
:
content
,
"updated_at"
:
int
(
time
.
time
())}
)
Session
.
commit
()
return
self
.
get_memory_by_id
(
id
)
except
:
return
None
def
get_memories
(
self
)
->
List
[
MemoryModel
]:
try
:
memories
=
Memory
.
select
()
return
[
MemoryModel
(
**
model_
to_dict
(
memory
)
)
for
memory
in
memories
]
memories
=
Session
.
query
(
Memory
).
all
()
return
[
MemoryModel
.
model_
validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memories_by_user_id
(
self
,
user_id
:
str
)
->
List
[
MemoryModel
]:
try
:
memories
=
Memory
.
select
().
wh
er
e
(
Memory
.
user_id
==
user_id
)
return
[
MemoryModel
(
**
model_
to_dict
(
memory
)
)
for
memory
in
memories
]
memories
=
Session
.
qu
er
y
(
Memory
).
filter_by
(
user_id
=
user_id
)
.
all
()
return
[
MemoryModel
.
model_
validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memory_by_id
(
self
,
id
)
->
Optional
[
MemoryModel
]:
def
get_memory_by_id
(
self
,
id
:
str
)
->
Optional
[
MemoryModel
]:
try
:
memory
=
Memory
.
get
(
Memory
.
id
==
id
)
return
MemoryModel
(
**
model_
to_dict
(
memory
)
)
memory
=
Session
.
get
(
Memory
,
id
)
return
MemoryModel
.
model_
validate
(
memory
)
except
:
return
None
def
delete_memory_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Memory
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
...
...
@@ -112,21 +110,17 @@ class MemoriesTable:
def
delete_memories_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
user_id
==
user_id
)
query
.
execute
()
Session
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
def
delete_memory_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
,
Memory
.
user_id
==
user_id
)
query
.
execute
()
Session
.
query
(
Memory
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
Memories
=
MemoriesTable
(
DB
)
Memories
=
MemoriesTable
()
backend/apps/webui/models/models.py
View file @
d0e89a03
...
...
@@ -2,13 +2,10 @@ import json
import
logging
from
typing
import
Optional
import
peewee
as
pw
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
,
Session
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
...
...
@@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
pass
class
Model
(
pw
.
Model
):
id
=
pw
.
TextField
(
unique
=
True
)
class
Model
(
Base
):
__tablename__
=
"model"
id
=
Column
(
Text
,
primary_key
=
True
)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id
=
pw
.
TextField
(
)
user_id
=
Column
(
Text
)
base_model_id
=
pw
.
TextField
(
null
=
True
)
base_model_id
=
Column
(
Text
,
nullable
=
True
)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name
=
pw
.
TextField
(
)
name
=
Column
(
Text
)
"""
The human-readable display name of the model.
"""
params
=
JSONField
(
)
params
=
Column
(
JSONField
)
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta
=
JSONField
(
)
meta
=
Column
(
JSONField
)
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Meta
:
database
=
DB
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
ModelModel
(
BaseModel
):
...
...
@@ -92,6 +88,8 @@ class ModelModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -115,12 +113,6 @@ class ModelForm(BaseModel):
class
ModelsTable
:
def
__init__
(
self
,
db
:
pw
.
SqliteDatabase
|
pw
.
PostgresqlDatabase
,
):
self
.
db
=
db
self
.
db
.
create_tables
([
Model
])
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
...
...
@@ -134,10 +126,13 @@ class ModelsTable:
}
)
try
:
result
=
Model
.
create
(
**
model
.
model_dump
())
result
=
Model
(
**
model
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
model
return
ModelModel
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -145,23 +140,25 @@ class ModelsTable:
return
None
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
return
[
ModelModel
(
**
model_to_dict
(
model
))
for
model
in
Model
.
select
()]
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
Session
.
query
(
Model
).
all
()
]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_
to_dict
(
model
)
)
model
=
Session
.
get
(
Model
,
id
)
return
ModelModel
.
model_
validate
(
model
)
except
:
return
None
def
update_model_by_id
(
self
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
try
:
# update only the fields that are present in the model
query
=
Model
.
update
(
**
model
.
model_dump
()).
wh
er
e
(
Model
.
id
==
id
)
query
.
execute
()
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_
to_dict
(
model
)
)
model
=
Session
.
qu
er
y
(
Model
).
get
(
id
)
model
.
update
(
**
model
.
model_dump
()
)
Session
.
commit
()
Session
.
refresh
(
model
)
return
ModelModel
.
model_
validate
(
model
)
except
Exception
as
e
:
print
(
e
)
...
...
@@ -169,11 +166,10 @@ class ModelsTable:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Model
.
delete
().
where
(
Model
.
id
==
id
)
query
.
execute
()
Session
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
Models
=
ModelsTable
(
DB
)
Models
=
ModelsTable
()
backend/apps/webui/models/prompts.py
View file @
d0e89a03
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
from
utils.utils
import
decode_token
from
utils.misc
import
get_gravatar_url
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
Session
import
json
...
...
@@ -16,15 +13,14 @@ import json
####################
class
Prompt
(
Model
):
command
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
title
=
TextField
()
content
=
TextField
()
timestamp
=
BigIntegerField
()
class
Prompt
(
Base
):
__tablename__
=
"prompt"
class
Meta
:
database
=
DB
command
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
title
=
Column
(
Text
)
content
=
Column
(
Text
)
timestamp
=
Column
(
BigInteger
)
class
PromptModel
(
BaseModel
):
...
...
@@ -34,6 +30,8 @@ class PromptModel(BaseModel):
content
:
str
timestamp
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -48,10 +46,6 @@ class PromptForm(BaseModel):
class
PromptsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Prompt
])
def
insert_new_prompt
(
self
,
user_id
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
...
...
@@ -66,53 +60,48 @@ class PromptsTable:
)
try
:
result
=
Prompt
.
create
(
**
prompt
.
model_dump
())
result
=
Prompt
(
**
prompt
.
dict
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
p
rompt
return
P
rompt
Model
.
model_validate
(
result
)
else
:
return
None
except
:
except
Exception
as
e
:
return
None
def
get_prompt_by_command
(
self
,
command
:
str
)
->
Optional
[
PromptModel
]:
try
:
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
return
PromptModel
(
**
model_
to_dict
(
prompt
)
)
prompt
=
Session
.
query
(
Prompt
).
filter_by
(
command
=
command
)
.
first
()
return
PromptModel
.
model_
validate
(
prompt
)
except
:
return
None
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
return
[
PromptModel
(
**
model_to_dict
(
prompt
))
for
prompt
in
Prompt
.
select
()
# .limit(limit).offset(skip)
PromptModel
.
model_validate
(
prompt
)
for
prompt
in
Session
.
query
(
Prompt
).
all
()
]
def
update_prompt_by_command
(
self
,
command
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
try
:
query
=
Prompt
.
update
(
title
=
form_data
.
title
,
content
=
form_data
.
content
,
timestamp
=
int
(
time
.
time
()),
).
where
(
Prompt
.
command
==
command
)
query
.
execute
()
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
return
PromptModel
(
**
model_to_dict
(
prompt
))
prompt
=
Session
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
prompt
.
title
=
form_data
.
title
prompt
.
content
=
form_data
.
content
prompt
.
timestamp
=
int
(
time
.
time
())
Session
.
commit
()
return
PromptModel
.
model_validate
(
prompt
)
except
:
return
None
def
delete_prompt_by_command
(
self
,
command
:
str
)
->
bool
:
try
:
query
=
Prompt
.
delete
().
where
((
Prompt
.
command
==
command
))
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Prompt
).
filter_by
(
command
=
command
).
delete
()
return
True
except
:
return
False
Prompts
=
PromptsTable
(
DB
)
Prompts
=
PromptsTable
()
backend/apps/webui/models/tags.py
View file @
d0e89a03
from
pydantic
import
BaseModel
from
typing
import
List
,
Union
,
Optional
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
json
import
uuid
import
time
import
logging
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
Session
from
config
import
SRC_LOG_LEVELS
...
...
@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Tag
(
Model
):
id
=
CharField
(
unique
=
True
)
name
=
CharField
()
user_id
=
CharField
()
data
=
TextField
(
null
=
True
)
class
Tag
(
Base
):
__tablename__
=
"tag"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
)
user_id
=
Column
(
String
)
data
=
Column
(
Text
,
nullable
=
True
)
class
ChatIdTag
(
Model
):
id
=
CharField
(
unique
=
True
)
tag_name
=
CharField
()
chat_id
=
CharField
()
user_id
=
CharField
()
timestamp
=
BigIntegerField
()
class
ChatIdTag
(
Base
):
__tablename__
=
"chatidtag"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
tag_name
=
Column
(
String
)
chat_id
=
Column
(
String
)
user_id
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
TagModel
(
BaseModel
):
...
...
@@ -47,6 +45,8 @@ class TagModel(BaseModel):
user_id
:
str
data
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
class
ChatIdTagModel
(
BaseModel
):
id
:
str
...
...
@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
user_id
:
str
timestamp
:
int
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -75,17 +77,17 @@ class ChatTagsResponse(BaseModel):
class
TagTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
db
.
create_tables
([
Tag
,
ChatIdTag
])
def
insert_new_tag
(
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
id
=
str
(
uuid
.
uuid4
())
tag
=
TagModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"name"
:
name
})
try
:
result
=
Tag
.
create
(
**
tag
.
model_dump
())
result
=
Tag
(
**
tag
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
t
ag
return
T
ag
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -95,8 +97,8 @@ class TagTable:
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
try
:
tag
=
Tag
.
get
(
Tag
.
name
==
name
,
Tag
.
user_id
==
user_id
)
return
TagModel
(
**
model_
to_dict
(
tag
)
)
tag
=
Session
.
query
(
Tag
).
filter
(
name
=
name
,
user_id
=
user_id
)
.
first
()
return
TagModel
.
model_
validate
(
tag
)
except
Exception
as
e
:
return
None
...
...
@@ -118,9 +120,12 @@ class TagTable:
}
)
try
:
result
=
ChatIdTag
.
create
(
**
chatIdTag
.
model_dump
())
result
=
ChatIdTag
(
**
chatIdTag
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
c
hatIdTag
return
C
hatIdTag
Model
.
model_validate
(
result
)
else
:
return
None
except
:
...
...
@@ -128,71 +133,84 @@ class TagTable:
def
get_tags_by_user_id
(
self
,
user_id
:
str
)
->
List
[
TagModel
]:
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
(
ChatIdTag
.
user_id
==
user_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
Session
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
return
[
TagModel
(
**
model_to_dict
(
tag
))
for
tag
in
Tag
.
select
()
.
where
(
Tag
.
user_id
==
user_id
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
TagModel
.
model_validate
(
tag
)
for
tag
in
(
Session
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
def
get_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
)
->
List
[
TagModel
]:
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
chat_id
==
chat_id
))
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
Session
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
chat_id
=
chat_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
return
[
TagModel
(
**
model_to_dict
(
tag
))
for
tag
in
Tag
.
select
()
.
where
(
Tag
.
user_id
==
user_id
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
TagModel
.
model_validate
(
tag
)
for
tag
in
(
Session
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
def
get_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
Optional
[
ChatIdTagModel
]:
)
->
List
[
ChatIdTagModel
]:
return
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
))
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
tag_name
==
tag_name
))
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
ChatIdTagModel
.
model_validate
(
chat_id_tag
)
for
chat_id_tag
in
(
Session
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
tag_name
=
tag_name
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
def
count_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
int
:
return
(
ChatIdTag
.
select
(
)
.
where
((
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
)
Session
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
count
()
)
def
delete_tag_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
res
=
(
Session
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
delete
()
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
Session
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
)
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
return
True
except
Exception
as
e
:
log
.
error
(
f
"delete_tag:
{
e
}
"
)
...
...
@@ -202,21 +220,18 @@ class TagTable:
self
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
chat_id
==
chat
_id
)
&
(
ChatIdTag
.
user_id
==
user_id
)
res
=
(
Session
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
chat_id
=
chat_id
,
user_id
=
user
_id
)
.
delete
(
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
Session
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
)
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
return
True
except
Exception
as
e
:
...
...
@@ -234,4 +249,4 @@ class TagTable:
return
True
Tags
=
TagTable
(
DB
)
Tags
=
TagTable
()
backend/apps/webui/models/tools.py
View file @
d0e89a03
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
JSONField
,
Session
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Tool
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
name
=
TextField
()
content
=
TextField
()
specs
=
JSONField
()
meta
=
JSONField
()
valves
=
JSONField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Tool
(
Base
):
__tablename__
=
"tool"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
name
=
Column
(
Text
)
content
=
Column
(
Text
)
specs
=
Column
(
JSONField
)
meta
=
Column
(
JSONField
)
valves
=
Column
(
JSONField
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
ToolMeta
(
BaseModel
):
...
...
@@ -51,6 +50,8 @@ class ToolModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -78,9 +79,6 @@ class ToolValves(BaseModel):
class
ToolsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Tool
])
def
insert_new_tool
(
self
,
user_id
:
str
,
form_data
:
ToolForm
,
specs
:
List
[
dict
]
...
...
@@ -96,9 +94,12 @@ class ToolsTable:
)
try
:
result
=
Tool
.
create
(
**
tool
.
model_dump
())
result
=
Tool
(
**
tool
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
t
ool
return
T
ool
Model
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
...
...
@@ -107,17 +108,17 @@ class ToolsTable:
def
get_tool_by_id
(
self
,
id
:
str
)
->
Optional
[
ToolModel
]:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolModel
(
**
model_
to_dict
(
tool
)
)
tool
=
Session
.
get
(
Tool
,
id
)
return
ToolModel
.
model_
validate
(
tool
)
except
:
return
None
def
get_tools
(
self
)
->
List
[
ToolModel
]:
return
[
ToolModel
(
**
model_
to_dict
(
tool
)
)
for
tool
in
Tool
.
select
()]
return
[
ToolModel
.
model_
validate
(
tool
)
for
tool
in
Session
.
query
(
Tool
).
all
()]
def
get_tool_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
tool
=
Session
.
get
(
Tool
,
id
)
return
tool
.
valves
if
tool
.
valves
else
{}
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
...
...
@@ -125,14 +126,11 @@ class ToolsTable:
def
update_tool_valves_by_id
(
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
ToolValves
]:
try
:
query
=
Tool
.
update
(
**
{
"valves"
:
valves
},
updated_at
=
int
(
time
.
time
()),
).
where
(
Tool
.
id
==
id
)
query
.
execute
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolValves
(
**
model_to_dict
(
tool
))
Session
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
"valves"
:
valves
,
"updated_at"
:
int
(
time
.
time
())}
)
Session
.
commit
()
return
self
.
get_tool_by_id
(
id
)
except
:
return
None
...
...
@@ -179,25 +177,21 @@ class ToolsTable:
def
update_tool_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
ToolModel
]:
try
:
query
=
Tool
.
update
(
**
updated
,
updated_at
=
int
(
time
.
time
()),
).
where
(
Tool
.
id
==
id
)
query
.
execute
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolModel
(
**
model_to_dict
(
tool
))
tool
=
Session
.
get
(
Tool
,
id
)
tool
.
update
(
**
updated
)
tool
.
updated_at
=
int
(
time
.
time
())
Session
.
commit
()
Session
.
refresh
(
tool
)
return
ToolModel
.
model_validate
(
tool
)
except
:
return
None
def
delete_tool_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Tool
.
delete
().
where
((
Tool
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
Tool
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
Tools
=
ToolsTable
(
DB
)
Tools
=
ToolsTable
()
backend/apps/webui/models/users.py
View file @
d0e89a03
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
,
parse_obj_as
from
typing
import
List
,
Union
,
Optional
import
time
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
utils.misc
import
get_gravatar_url
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
,
Session
from
apps.webui.models.chats
import
Chats
####################
...
...
@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
####################
class
User
(
Model
):
id
=
CharField
(
unique
=
True
)
name
=
CharField
()
email
=
CharField
()
role
=
CharField
()
profile_image_url
=
TextField
()
class
User
(
Base
):
__tablename__
=
"user"
last_active_at
=
BigIntegerField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
id
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
)
email
=
Column
(
String
)
role
=
Column
(
String
)
profile_image_url
=
Column
(
Text
)
api_key
=
CharField
(
null
=
True
,
unique
=
True
)
settings
=
JSONField
(
null
=
True
)
info
=
JSONField
(
null
=
True
)
last_active_at
=
Column
(
BigInteger
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
oauth_sub
=
TextField
(
null
=
True
,
unique
=
True
)
api_key
=
Column
(
String
,
nullable
=
True
,
unique
=
True
)
settings
=
Column
(
JSONField
,
nullable
=
True
)
info
=
Column
(
JSONField
,
nullable
=
True
)
class
Meta
:
database
=
DB
oauth_sub
=
Column
(
Text
,
unique
=
True
)
class
UserSettings
(
BaseModel
):
...
...
@@ -57,6 +57,8 @@ class UserModel(BaseModel):
oauth_sub
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
class
UsersTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
User
])
def
insert_new_user
(
self
,
...
...
@@ -102,7 +101,10 @@ class UsersTable:
"oauth_sub"
:
oauth_sub
,
}
)
result
=
User
.
create
(
**
user
.
model_dump
())
result
=
User
(
**
user
.
model_dump
())
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
user
else
:
...
...
@@ -110,56 +112,57 @@ class UsersTable:
def
get_user_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
except
:
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
Exception
as
e
:
return
None
def
get_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
api_key
==
api_key
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
Session
.
query
(
User
).
filter_by
(
api_key
=
api_key
)
.
first
()
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
email
==
email
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
Session
.
query
(
User
).
filter_by
(
email
=
email
)
.
first
()
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
get_user_by_oauth_sub
(
self
,
sub
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
oauth_sub
==
sub
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
Session
.
query
(
User
).
filter_by
(
oauth_sub
=
sub
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
get_users
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
return
[
UserModel
(
**
model_to_dict
(
user
))
for
user
in
User
.
select
()
# .limit(limit).offset(skip)
]
users
=
(
Session
.
query
(
User
)
# .offset(skip).limit(limit)
.
all
()
)
return
[
UserModel
.
model_validate
(
user
)
for
user
in
users
]
def
get_num_users
(
self
)
->
Optional
[
int
]:
return
User
.
select
(
).
count
()
return
Session
.
query
(
User
).
count
()
def
get_first_user
(
self
)
->
UserModel
:
try
:
user
=
User
.
select
(
).
order_by
(
User
.
created_at
).
first
()
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
Session
.
query
(
User
).
order_by
(
User
.
created_at
).
first
()
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_role_by_id
(
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
role
=
role
).
where
(
User
.
id
==
id
)
query
.
execute
()
Session
.
query
(
User
).
filter_by
(
id
=
id
)
.
update
(
{
"
role
"
:
role
}
)
Session
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
...
...
@@ -167,23 +170,25 @@ class UsersTable:
self
,
id
:
str
,
profile_image_url
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
profile_image_url
=
profile_image_url
).
wher
e
(
User
.
id
==
id
Session
.
query
(
User
).
filter_by
(
id
=
id
).
updat
e
(
{
"profile_image_url"
:
profile_image_url
}
)
query
.
execute
()
Session
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_last_active_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
last_active_at
=
int
(
time
.
time
())).
where
(
User
.
id
==
id
)
query
.
execute
()
Session
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"last_active_at"
:
int
(
time
.
time
())}
)
Session
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
...
...
@@ -191,22 +196,22 @@ class UsersTable:
self
,
id
:
str
,
oauth_sub
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
oauth_sub
=
oauth_sub
).
where
(
User
.
id
==
id
)
query
.
execute
()
Session
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"oauth_sub"
:
oauth_sub
})
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
**
updated
)
.
where
(
User
.
id
==
id
)
query
.
execute
()
Session
.
query
(
User
).
filter_by
(
id
=
id
)
.
update
(
updated
)
Session
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
# return UserModel(**user.dict())
except
Exception
as
e
:
return
None
def
delete_user_by_id
(
self
,
id
:
str
)
->
bool
:
...
...
@@ -216,8 +221,8 @@ class UsersTable:
if
result
:
# Delete User
query
=
User
.
delete
().
where
(
User
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
Session
.
query
(
User
).
filter_by
(
id
=
id
).
delete
(
)
Session
.
commit
()
return
True
else
:
...
...
@@ -227,19 +232,18 @@ class UsersTable:
def
update_user_api_key_by_id
(
self
,
id
:
str
,
api_key
:
str
)
->
str
:
try
:
query
=
User
.
update
(
api_key
=
api_key
).
where
(
User
.
id
==
id
)
result
=
query
.
execute
()
result
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"api_key"
:
api_key
})
Session
.
commit
()
return
True
if
result
==
1
else
False
except
:
return
False
def
get_user_api_key_by_id
(
self
,
id
:
str
)
->
Optional
[
str
]:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
user
.
api_key
except
:
except
Exception
as
e
:
return
None
Users
=
UsersTable
(
DB
)
Users
=
UsersTable
()
backend/apps/webui/routers/chats.py
View file @
d0e89a03
...
...
@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@
router
.
get
(
"/list/user/{user_id}"
,
response_model
=
List
[
ChatTitleIdResponse
])
async
def
get_user_chat_list_by_user_id
(
user_id
:
str
,
user
=
Depends
(
get_admin_user
),
skip
:
int
=
0
,
limit
:
int
=
50
user_id
:
str
,
user
=
Depends
(
get_admin_user
),
skip
:
int
=
0
,
limit
:
int
=
50
,
):
return
Chats
.
get_chat_list_by_user_id
(
user_id
,
include_archived
=
True
,
skip
=
skip
,
limit
=
limit
...
...
@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
@
router
.
get
(
"/all/archived"
,
response_model
=
List
[
ChatResponse
])
async
def
get_user_chats
(
user
=
Depends
(
get_verified_user
)):
async
def
get_user_
archived_
chats
(
user
=
Depends
(
get_verified_user
)):
return
[
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
for
chat
in
Chats
.
get_archived_chats_by_user_id
(
user
.
id
)
...
...
backend/apps/webui/routers/documents.py
View file @
d0e89a03
...
...
@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
@
router
.
post
(
"/doc/update"
,
response_model
=
Optional
[
DocumentResponse
])
async
def
update_doc_by_name
(
name
:
str
,
form_data
:
DocumentUpdateForm
,
user
=
Depends
(
get_admin_user
)
name
:
str
,
form_data
:
DocumentUpdateForm
,
user
=
Depends
(
get_admin_user
),
):
doc
=
Documents
.
update_doc_by_name
(
name
,
form_data
)
if
doc
:
...
...
backend/apps/webui/routers/files.py
View file @
d0e89a03
...
...
@@ -50,10 +50,7 @@ router = APIRouter()
@
router
.
post
(
"/"
)
def
upload_file
(
file
:
UploadFile
=
File
(...),
user
=
Depends
(
get_verified_user
),
):
def
upload_file
(
file
:
UploadFile
=
File
(...),
user
=
Depends
(
get_verified_user
)):
log
.
info
(
f
"file.content_type:
{
file
.
content_type
}
"
)
try
:
unsanitized_filename
=
file
.
filename
...
...
backend/apps/webui/routers/memories.py
View file @
d0e89a03
...
...
@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel):
@
router
.
post
(
"/add"
,
response_model
=
Optional
[
MemoryModel
])
async
def
add_memory
(
request
:
Request
,
form_data
:
AddMemoryForm
,
user
=
Depends
(
get_verified_user
)
request
:
Request
,
form_data
:
AddMemoryForm
,
user
=
Depends
(
get_verified_user
),
):
memory
=
Memories
.
insert_new_memory
(
user
.
id
,
form_data
.
content
)
memory_embedding
=
request
.
app
.
state
.
EMBEDDING_FUNCTION
(
memory
.
content
)
...
...
backend/apps/webui/routers/models.py
View file @
d0e89a03
...
...
@@ -5,6 +5,7 @@ from typing import List, Union, Optional
from
fastapi
import
APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.webui.models.models
import
Models
,
ModelModel
,
ModelForm
,
ModelResponse
from
utils.utils
import
get_verified_user
,
get_admin_user
...
...
@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)):
@
router
.
post
(
"/add"
,
response_model
=
Optional
[
ModelModel
])
async
def
add_new_model
(
request
:
Request
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
)
request
:
Request
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
),
):
if
form_data
.
id
in
request
.
app
.
state
.
MODELS
:
raise
HTTPException
(
...
...
@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@
router
.
post
(
"/update"
,
response_model
=
Optional
[
ModelModel
])
async
def
update_model_by_id
(
request
:
Request
,
id
:
str
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
)
request
:
Request
,
id
:
str
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
),
):
model
=
Models
.
get_model_by_id
(
id
)
if
model
:
...
...
backend/apps/webui/routers/prompts.py
View file @
d0e89a03
...
...
@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
@
router
.
post
(
"/command/{command}/update"
,
response_model
=
Optional
[
PromptModel
])
async
def
update_prompt_by_command
(
command
:
str
,
form_data
:
PromptForm
,
user
=
Depends
(
get_admin_user
)
command
:
str
,
form_data
:
PromptForm
,
user
=
Depends
(
get_admin_user
),
):
prompt
=
Prompts
.
update_prompt_by_command
(
f
"/
{
command
}
"
,
form_data
)
if
prompt
:
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment