Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
da403f3e
Commit
da403f3e
authored
Jun 24, 2024
by
Jonathan Rohde
Browse files
feat(sqlalchemy): use session factory instead of context manager
parent
eb01e8d2
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
654 additions
and
773 deletions
+654
-773
backend/apps/webui/internal/db.py
backend/apps/webui/internal/db.py
+1
-11
backend/apps/webui/models/auths.py
backend/apps/webui/models/auths.py
+61
-70
backend/apps/webui/models/chats.py
backend/apps/webui/models/chats.py
+139
-161
backend/apps/webui/models/documents.py
backend/apps/webui/models/documents.py
+36
-43
backend/apps/webui/models/files.py
backend/apps/webui/models/files.py
+14
-22
backend/apps/webui/models/functions.py
backend/apps/webui/models/functions.py
+53
-64
backend/apps/webui/models/memories.py
backend/apps/webui/models/memories.py
+25
-35
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+19
-25
backend/apps/webui/models/prompts.py
backend/apps/webui/models/prompts.py
+43
-50
backend/apps/webui/models/tags.py
backend/apps/webui/models/tags.py
+99
-109
backend/apps/webui/models/tools.py
backend/apps/webui/models/tools.py
+26
-33
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+117
-134
backend/main.py
backend/main.py
+10
-4
backend/test/apps/webui/routers/test_chats.py
backend/test/apps/webui/routers/test_chats.py
+2
-0
backend/test/util/abstract_integration_test.py
backend/test/util/abstract_integration_test.py
+9
-12
No files found.
backend/apps/webui/internal/db.py
View file @
da403f3e
...
...
@@ -57,14 +57,4 @@ SessionLocal = sessionmaker(
autocommit
=
False
,
autoflush
=
False
,
bind
=
engine
,
expire_on_commit
=
False
)
Base
=
declarative_base
()
@
contextmanager
def
get_session
():
session
=
scoped_session
(
SessionLocal
)
try
:
yield
session
session
.
commit
()
except
Exception
as
e
:
session
.
rollback
()
raise
e
Session
=
scoped_session
(
SessionLocal
)
backend/apps/webui/models/auths.py
View file @
da403f3e
...
...
@@ -3,12 +3,11 @@ from typing import Optional
import
uuid
import
logging
from
sqlalchemy
import
String
,
Column
,
Boolean
from
sqlalchemy.orm
import
Session
from
apps.webui.models.users
import
UserModel
,
Users
from
utils.utils
import
verify_password
from
apps.webui.internal.db
import
Base
,
get_s
ession
from
apps.webui.internal.db
import
Base
,
S
ession
from
config
import
SRC_LOG_LEVELS
...
...
@@ -103,7 +102,6 @@ class AuthsTable:
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
log
.
info
(
"insert_new_auth"
)
id
=
str
(
uuid
.
uuid4
())
...
...
@@ -112,14 +110,13 @@ class AuthsTable:
**
{
"id"
:
id
,
"email"
:
email
,
"password"
:
password
,
"active"
:
True
}
)
result
=
Auth
(
**
auth
.
model_dump
())
db
.
add
(
result
)
Session
.
add
(
result
)
user
=
Users
.
insert_new_user
(
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
)
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
and
user
:
return
user
...
...
@@ -128,9 +125,8 @@ class AuthsTable:
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user:
{
email
}
"
)
with
get_session
()
as
db
:
try
:
auth
=
db
.
query
(
Auth
).
filter_by
(
email
=
email
,
active
=
True
).
first
()
auth
=
Session
.
query
(
Auth
).
filter_by
(
email
=
email
,
active
=
True
).
first
()
if
auth
:
if
verify_password
(
password
,
auth
.
password
):
user
=
Users
.
get_user_by_id
(
auth
.
id
)
...
...
@@ -144,7 +140,6 @@ class AuthsTable:
def
authenticate_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user_by_api_key:
{
api_key
}
"
)
with
get_session
()
as
db
:
# if no api_key, return None
if
not
api_key
:
return
None
...
...
@@ -157,9 +152,8 @@ class AuthsTable:
def
authenticate_user_by_trusted_header
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user_by_trusted_header:
{
email
}
"
)
with
get_session
()
as
db
:
try
:
auth
=
db
.
query
(
Auth
).
filter
(
email
=
email
,
active
=
True
).
first
()
auth
=
Session
.
query
(
Auth
).
filter
(
email
=
email
,
active
=
True
).
first
()
if
auth
:
user
=
Users
.
get_user_by_id
(
auth
.
id
)
return
user
...
...
@@ -167,31 +161,28 @@ class AuthsTable:
return
None
def
update_user_password_by_id
(
self
,
id
:
str
,
new_password
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
result
=
(
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"password"
:
new_password
})
Session
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"password"
:
new_password
})
)
return
True
if
result
==
1
else
False
except
:
return
False
def
update_email_by_id
(
self
,
id
:
str
,
email
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
result
=
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"email"
:
email
})
result
=
Session
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"email"
:
email
})
return
True
if
result
==
1
else
False
except
:
return
False
def
delete_auth_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
# Delete User
result
=
Users
.
delete_user_by_id
(
id
)
if
result
:
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
delete
()
Session
.
query
(
Auth
).
filter_by
(
id
=
id
).
delete
()
return
True
else
:
...
...
backend/apps/webui/models/chats.py
View file @
da403f3e
...
...
@@ -6,9 +6,8 @@ import uuid
import
time
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Boolean
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
get_s
ession
from
apps.webui.internal.db
import
Base
,
S
ession
####################
...
...
@@ -80,7 +79,6 @@ class ChatTitleIdResponse(BaseModel):
class
ChatTable
:
def
insert_new_chat
(
self
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
with
get_session
()
as
db
:
id
=
str
(
uuid
.
uuid4
())
chat
=
ChatModel
(
**
{
...
...
@@ -98,29 +96,27 @@ class ChatTable:
)
result
=
Chat
(
**
chat
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
return
ChatModel
.
model_validate
(
result
)
if
result
else
None
def
update_chat_by_id
(
self
,
id
:
str
,
chat
:
dict
)
->
Optional
[
ChatModel
]:
with
get_session
()
as
db
:
try
:
chat_obj
=
db
.
get
(
Chat
,
id
)
chat_obj
=
Session
.
get
(
Chat
,
id
)
chat_obj
.
chat
=
json
.
dumps
(
chat
)
chat_obj
.
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
chat_obj
.
updated_at
=
int
(
time
.
time
())
db
.
commit
()
db
.
refresh
(
chat_obj
)
Session
.
commit
()
Session
.
refresh
(
chat_obj
)
return
ChatModel
.
model_validate
(
chat_obj
)
except
Exception
as
e
:
return
None
def
insert_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
with
get_session
()
as
db
:
# Get the existing chat to share
chat
=
db
.
get
(
Chat
,
chat_id
)
chat
=
Session
.
get
(
Chat
,
chat_id
)
# Check if the chat is already shared
if
chat
.
share_id
:
return
self
.
get_chat_by_id_and_user_id
(
chat
.
share_id
,
"shared"
)
...
...
@@ -136,12 +132,12 @@ class ChatTable:
}
)
shared_result
=
Chat
(
**
shared_chat
.
model_dump
())
db
.
add
(
shared_result
)
db
.
commit
()
db
.
refresh
(
shared_result
)
Session
.
add
(
shared_result
)
Session
.
commit
()
Session
.
refresh
(
shared_result
)
# Update the original chat with the share_id
result
=
(
db
.
query
(
Chat
)
Session
.
query
(
Chat
)
.
filter_by
(
id
=
chat_id
)
.
update
({
"share_id"
:
shared_chat
.
id
})
)
...
...
@@ -149,15 +145,14 @@ class ChatTable:
return
shared_chat
if
(
shared_result
and
result
)
else
None
def
update_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
with
get_session
()
as
db
:
try
:
print
(
"update_shared_chat_by_id"
)
chat
=
db
.
get
(
Chat
,
chat_id
)
chat
=
Session
.
get
(
Chat
,
chat_id
)
print
(
chat
)
chat
.
title
=
chat
.
title
chat
.
chat
=
chat
.
chat
db
.
commit
()
db
.
refresh
(
chat
)
Session
.
commit
()
Session
.
refresh
(
chat
)
return
self
.
get_chat_by_id
(
chat
.
share_id
)
except
:
...
...
@@ -165,8 +160,7 @@ class ChatTable:
def
delete_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
user_id
=
f
"shared-
{
chat_id
}
"
).
delete
()
Session
.
query
(
Chat
).
filter_by
(
user_id
=
f
"shared-
{
chat_id
}
"
).
delete
()
return
True
except
:
return
False
...
...
@@ -175,30 +169,27 @@ class ChatTable:
self
,
id
:
str
,
share_id
:
Optional
[
str
]
)
->
Optional
[
ChatModel
]:
try
:
with
get_session
()
as
db
:
chat
=
db
.
get
(
Chat
,
id
)
chat
=
Session
.
get
(
Chat
,
id
)
chat
.
share_id
=
share_id
db
.
commit
()
db
.
refresh
(
chat
)
return
chat
Session
.
commit
()
Session
.
refresh
(
chat
)
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
toggle_chat_archive_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
with
get_session
()
as
db
:
chat
=
self
.
get_chat_by_id
(
id
)
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
update
({
"archived"
:
not
chat
.
archived
}
)
return
self
.
get_chat_by_id
(
id
)
chat
=
Session
.
get
(
Chat
,
id
)
chat
.
archived
=
not
chat
.
archived
Session
.
commit
(
)
Session
.
refresh
(
chat
)
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
archive_all_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
update
({
"archived"
:
True
})
Session
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
update
({
"archived"
:
True
})
return
True
except
:
return
False
...
...
@@ -206,9 +197,8 @@ class ChatTable:
def
get_archived_chat_list_by_user_id
(
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
Session
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
# .limit(limit).offset(skip)
...
...
@@ -223,8 +213,7 @@ class ChatTable:
skip
:
int
=
0
,
limit
:
int
=
50
,
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
query
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
)
query
=
Session
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
)
if
not
include_archived
:
query
=
query
.
filter_by
(
archived
=
False
)
all_chats
=
(
...
...
@@ -237,9 +226,8 @@ class ChatTable:
def
get_chat_list_by_chat_ids
(
self
,
chat_ids
:
List
[
str
],
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
Session
.
query
(
Chat
)
.
filter
(
Chat
.
id
.
in_
(
chat_ids
))
.
filter_by
(
archived
=
False
)
.
order_by
(
Chat
.
updated_at
.
desc
())
...
...
@@ -249,16 +237,14 @@ class ChatTable:
def
get_chat_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
with
get_session
()
as
db
:
chat
=
db
.
get
(
Chat
,
id
)
chat
=
Session
.
get
(
Chat
,
id
)
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
get_chat_by_share_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
with
get_session
()
as
db
:
chat
=
db
.
query
(
Chat
).
filter_by
(
share_id
=
id
).
first
()
chat
=
Session
.
query
(
Chat
).
filter_by
(
share_id
=
id
).
first
()
if
chat
:
return
self
.
get_chat_by_id
(
id
)
...
...
@@ -269,34 +255,30 @@ class ChatTable:
def
get_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
with
get_session
()
as
db
:
chat
=
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
first
()
chat
=
Session
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
first
()
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
get_chats
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
Session
.
query
(
Chat
)
# .limit(limit).offset(skip)
.
order_by
(
Chat
.
updated_at
.
desc
())
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
Session
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
Chat
.
updated_at
.
desc
())
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_archived_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
Session
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
.
order_by
(
Chat
.
updated_at
.
desc
())
)
...
...
@@ -304,8 +286,7 @@ class ChatTable:
def
delete_chat_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
delete
()
Session
.
query
(
Chat
).
filter_by
(
id
=
id
).
delete
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
...
...
@@ -313,8 +294,7 @@ class ChatTable:
def
delete_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
Session
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
...
...
@@ -322,21 +302,19 @@ class ChatTable:
def
delete_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
self
.
delete_shared_chats_by_user_id
(
user_id
)
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
delete
()
Session
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
def
delete_shared_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
chats_by_user
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
all
()
chats_by_user
=
Session
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
all
()
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
chats_by_user
]
db
.
query
(
Chat
).
filter
(
Chat
.
user_id
.
in_
(
shared_chat_ids
)).
delete
()
Session
.
query
(
Chat
).
filter
(
Chat
.
user_id
.
in_
(
shared_chat_ids
)).
delete
()
return
True
except
:
...
...
backend/apps/webui/models/documents.py
View file @
da403f3e
...
...
@@ -4,9 +4,8 @@ import time
import
logging
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
get_s
ession
from
apps.webui.internal.db
import
Base
,
S
ession
import
json
...
...
@@ -84,11 +83,10 @@ class DocumentsTable:
)
try
:
with
get_session
()
as
db
:
result
=
Document
(
**
document
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
DocumentModel
.
model_validate
(
result
)
else
:
...
...
@@ -98,31 +96,28 @@ class DocumentsTable:
def
get_doc_by_name
(
self
,
name
:
str
)
->
Optional
[
DocumentModel
]:
try
:
with
get_session
()
as
db
:
document
=
db
.
query
(
Document
).
filter_by
(
name
=
name
).
first
()
document
=
Session
.
query
(
Document
).
filter_by
(
name
=
name
).
first
()
return
DocumentModel
.
model_validate
(
document
)
if
document
else
None
except
:
return
None
def
get_docs
(
self
)
->
List
[
DocumentModel
]:
with
get_session
()
as
db
:
return
[
DocumentModel
.
model_validate
(
doc
)
for
doc
in
db
.
query
(
Document
).
all
()
DocumentModel
.
model_validate
(
doc
)
for
doc
in
Session
.
query
(
Document
).
all
()
]
def
update_doc_by_name
(
self
,
name
:
str
,
form_data
:
DocumentUpdateForm
)
->
Optional
[
DocumentModel
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
Session
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
{
"title"
:
form_data
.
title
,
"name"
:
form_data
.
name
,
"timestamp"
:
int
(
time
.
time
()),
}
)
db
.
commit
()
Session
.
commit
()
return
self
.
get_doc_by_name
(
form_data
.
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
...
...
@@ -132,18 +127,17 @@ class DocumentsTable:
self
,
name
:
str
,
updated
:
dict
)
->
Optional
[
DocumentModel
]:
try
:
with
get_session
()
as
db
:
doc
=
self
.
get_doc_by_name
(
name
)
doc_content
=
json
.
loads
(
doc
.
content
if
doc
.
content
else
"{}"
)
doc_content
=
{
**
doc_content
,
**
updated
}
db
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
Session
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
{
"content"
:
json
.
dumps
(
doc_content
),
"timestamp"
:
int
(
time
.
time
()),
}
)
db
.
commit
()
Session
.
commit
()
return
self
.
get_doc_by_name
(
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
...
...
@@ -151,8 +145,7 @@ class DocumentsTable:
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Document
).
filter_by
(
name
=
name
).
delete
()
Session
.
query
(
Document
).
filter_by
(
name
=
name
).
delete
()
return
True
except
:
return
False
...
...
backend/apps/webui/models/files.py
View file @
da403f3e
...
...
@@ -4,9 +4,8 @@ import time
import
logging
from
sqlalchemy
import
Column
,
String
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
JSONField
,
Base
,
get_s
ession
from
apps.webui.internal.db
import
JSONField
,
Base
,
S
ession
import
json
...
...
@@ -71,11 +70,10 @@ class FilesTable:
)
try
:
with
get_session
()
as
db
:
result
=
File
(
**
file
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
FileModel
.
model_validate
(
result
)
else
:
...
...
@@ -86,30 +84,24 @@ class FilesTable:
def
get_file_by_id
(
self
,
id
:
str
)
->
Optional
[
FileModel
]:
try
:
with
get_session
()
as
db
:
file
=
db
.
get
(
File
,
id
)
file
=
Session
.
get
(
File
,
id
)
return
FileModel
.
model_validate
(
file
)
except
:
return
None
def
get_files
(
self
)
->
List
[
FileModel
]:
with
get_session
()
as
db
:
return
[
FileModel
.
model_validate
(
file
)
for
file
in
db
.
query
(
File
).
all
()]
return
[
FileModel
.
model_validate
(
file
)
for
file
in
Session
.
query
(
File
).
all
()]
def
delete_file_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
File
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
Session
.
query
(
File
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
def
delete_all_files
(
self
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
File
).
delete
()
db
.
commit
()
Session
.
query
(
File
).
delete
()
return
True
except
:
return
False
...
...
backend/apps/webui/models/functions.py
View file @
da403f3e
...
...
@@ -4,9 +4,8 @@ import time
import
logging
from
sqlalchemy
import
Column
,
String
,
Text
,
BigInteger
,
Boolean
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
JSONField
,
Base
,
get_s
ession
from
apps.webui.internal.db
import
JSONField
,
Base
,
S
ession
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -100,11 +99,10 @@ class FunctionsTable:
)
try
:
with
get_session
()
as
db
:
result
=
Function
(
**
function
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
FunctionModel
.
model_validate
(
result
)
else
:
...
...
@@ -115,48 +113,42 @@ class FunctionsTable:
def
get_function_by_id
(
self
,
id
:
str
)
->
Optional
[
FunctionModel
]:
try
:
with
get_session
()
as
db
:
function
=
db
.
get
(
Function
,
id
)
function
=
Session
.
get
(
Function
,
id
)
return
FunctionModel
.
model_validate
(
function
)
except
:
return
None
def
get_functions
(
self
,
active_only
=
False
)
->
List
[
FunctionModel
]:
if
active_only
:
with
get_session
()
as
db
:
return
[
FunctionModel
.
model_validate
(
function
)
for
function
in
db
.
query
(
Function
).
filter_by
(
is_active
=
True
).
all
()
for
function
in
Session
.
query
(
Function
).
filter_by
(
is_active
=
True
).
all
()
]
else
:
with
get_session
()
as
db
:
return
[
FunctionModel
.
model_validate
(
function
)
for
function
in
db
.
query
(
Function
).
all
()
for
function
in
Session
.
query
(
Function
).
all
()
]
def
get_functions_by_type
(
self
,
type
:
str
,
active_only
=
False
)
->
List
[
FunctionModel
]:
if
active_only
:
with
get_session
()
as
db
:
return
[
FunctionModel
.
model_validate
(
function
)
for
function
in
db
.
query
(
Function
)
for
function
in
Session
.
query
(
Function
)
.
filter_by
(
type
=
type
,
is_active
=
True
)
.
all
()
]
else
:
with
get_session
()
as
db
:
return
[
FunctionModel
.
model_validate
(
function
)
for
function
in
db
.
query
(
Function
).
filter_by
(
type
=
type
).
all
()
for
function
in
Session
.
query
(
Function
).
filter_by
(
type
=
type
).
all
()
]
def
get_function_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
with
get_session
()
as
db
:
function
=
db
.
get
(
Function
,
id
)
function
=
Session
.
get
(
Function
,
id
)
return
function
.
valves
if
function
.
valves
else
{}
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
...
...
@@ -166,11 +158,11 @@ class FunctionsTable:
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
FunctionValves
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Function
).
filter_by
(
id
=
id
).
update
(
{
"valves"
:
valves
,
"
updated_at
"
:
int
(
time
.
time
())
}
)
db
.
commit
(
)
function
=
Session
.
get
(
Function
,
id
)
function
.
valves
=
valves
function
.
updated_at
=
int
(
time
.
time
())
Session
.
commit
(
)
Session
.
refresh
(
function
)
return
self
.
get_function_by_id
(
id
)
except
:
return
None
...
...
@@ -219,36 +211,33 @@ class FunctionsTable:
def
update_function_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
FunctionModel
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Function
).
filter_by
(
id
=
id
).
update
(
Session
.
query
(
Function
).
filter_by
(
id
=
id
).
update
(
{
**
updated
,
"updated_at"
:
int
(
time
.
time
()),
}
)
db
.
commit
()
Session
.
commit
()
return
self
.
get_function_by_id
(
id
)
except
:
return
None
def
deactivate_all_functions
(
self
)
->
Optional
[
bool
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Function
).
update
(
Session
.
query
(
Function
).
update
(
{
"is_active"
:
False
,
"updated_at"
:
int
(
time
.
time
()),
}
)
db
.
commit
()
Session
.
commit
()
return
True
except
:
return
None
def
delete_function_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Function
).
filter_by
(
id
=
id
).
delete
()
Session
.
query
(
Function
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
...
...
backend/apps/webui/models/memories.py
View file @
da403f3e
...
...
@@ -2,10 +2,8 @@ from pydantic import BaseModel, ConfigDict
from
typing
import
List
,
Union
,
Optional
from
sqlalchemy
import
Column
,
String
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
get_session
from
apps.webui.models.chats
import
Chats
from
apps.webui.internal.db
import
Base
,
Session
import
time
import
uuid
...
...
@@ -58,11 +56,10 @@ class MemoriesTable:
"updated_at"
:
int
(
time
.
time
()),
}
)
with
get_session
()
as
db
:
result
=
Memory
(
**
memory
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
MemoryModel
.
model_validate
(
result
)
else
:
...
...
@@ -74,62 +71,55 @@ class MemoriesTable:
content
:
str
,
)
->
Optional
[
MemoryModel
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
update
(
Session
.
query
(
Memory
).
filter_by
(
id
=
id
).
update
(
{
"content"
:
content
,
"updated_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
Session
.
commit
()
return
self
.
get_memory_by_id
(
id
)
except
:
return
None
def
get_memories
(
self
)
->
List
[
MemoryModel
]:
try
:
with
get_session
()
as
db
:
memories
=
db
.
query
(
Memory
).
all
()
memories
=
Session
.
query
(
Memory
).
all
()
return
[
MemoryModel
.
model_validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memories_by_user_id
(
self
,
user_id
:
str
)
->
List
[
MemoryModel
]:
try
:
with
get_session
()
as
db
:
memories
=
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
).
all
()
memories
=
Session
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
).
all
()
return
[
MemoryModel
.
model_validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memory_by_id
(
self
,
id
:
str
)
->
Optional
[
MemoryModel
]:
try
:
with
get_session
()
as
db
:
memory
=
db
.
get
(
Memory
,
id
)
memory
=
Session
.
get
(
Memory
,
id
)
return
MemoryModel
.
model_validate
(
memory
)
except
:
return
None
def
delete_memory_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
delete
()
Session
.
query
(
Memory
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
def
delete_memories_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
bool
:
def
delete_memories_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
).
delete
()
Session
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
def
delete_memory_by_id_and_user_id
(
self
,
db
:
Session
,
id
:
str
,
user_id
:
str
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Memory
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
Session
.
query
(
Memory
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
...
...
backend/apps/webui/models/models.py
View file @
da403f3e
...
...
@@ -4,9 +4,8 @@ from typing import Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_s
ession
from
apps.webui.internal.db
import
Base
,
JSONField
,
S
ession
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
...
...
@@ -127,11 +126,10 @@ class ModelsTable:
}
)
try
:
with
get_session
()
as
db
:
result
=
Model
(
**
model
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
ModelModel
.
model_validate
(
result
)
...
...
@@ -142,13 +140,11 @@ class ModelsTable:
return
None
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
with
get_session
()
as
db
:
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
Session
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
with
get_session
()
as
db
:
model
=
db
.
get
(
Model
,
id
)
model
=
Session
.
get
(
Model
,
id
)
return
ModelModel
.
model_validate
(
model
)
except
:
return
None
...
...
@@ -156,11 +152,10 @@ class ModelsTable:
def
update_model_by_id
(
self
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
try
:
# update only the fields that are present in the model
with
get_session
()
as
db
:
model
=
db
.
query
(
Model
).
get
(
id
)
model
=
Session
.
query
(
Model
).
get
(
id
)
model
.
update
(
**
model
.
model_dump
())
db
.
commit
()
db
.
refresh
(
model
)
Session
.
commit
()
Session
.
refresh
(
model
)
return
ModelModel
.
model_validate
(
model
)
except
Exception
as
e
:
print
(
e
)
...
...
@@ -169,8 +164,7 @@ class ModelsTable:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
Session
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
...
...
backend/apps/webui/models/prompts.py
View file @
da403f3e
...
...
@@ -3,9 +3,8 @@ from typing import List, Optional
import
time
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
get_s
ession
from
apps.webui.internal.db
import
Base
,
S
ession
import
json
...
...
@@ -50,7 +49,6 @@ class PromptsTable:
def
insert_new_prompt
(
self
,
user_id
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
with
get_session
()
as
db
:
prompt
=
PromptModel
(
**
{
"user_id"
:
user_id
,
...
...
@@ -63,9 +61,9 @@ class PromptsTable:
try
:
result
=
Prompt
(
**
prompt
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
PromptModel
.
model_validate
(
result
)
else
:
...
...
@@ -74,38 +72,33 @@ class PromptsTable:
return
None
def
get_prompt_by_command
(
self
,
command
:
str
)
->
Optional
[
PromptModel
]:
with
get_session
()
as
db
:
try
:
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
prompt
=
Session
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
return
PromptModel
.
model_validate
(
prompt
)
except
:
return
None
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
with
get_session
()
as
db
:
return
[
PromptModel
.
model_validate
(
prompt
)
for
prompt
in
db
.
query
(
Prompt
).
all
()
PromptModel
.
model_validate
(
prompt
)
for
prompt
in
Session
.
query
(
Prompt
).
all
()
]
def
update_prompt_by_command
(
self
,
command
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
with
get_session
()
as
db
:
try
:
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
prompt
=
Session
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
prompt
.
title
=
form_data
.
title
prompt
.
content
=
form_data
.
content
prompt
.
timestamp
=
int
(
time
.
time
())
db
.
commit
()
return
prompt
# return self.get_prompt_by_command(command)
Session
.
commit
()
return
PromptModel
.
model_validate
(
prompt
)
except
:
return
None
def
delete_prompt_by_command
(
self
,
command
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
delete
()
Session
.
query
(
Prompt
).
filter_by
(
command
=
command
).
delete
()
return
True
except
:
return
False
...
...
backend/apps/webui/models/tags.py
View file @
da403f3e
...
...
@@ -7,9 +7,8 @@ import time
import
logging
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
get_s
ession
from
apps.webui.internal.db
import
Base
,
S
ession
from
config
import
SRC_LOG_LEVELS
...
...
@@ -83,11 +82,10 @@ class TagTable:
id
=
str
(
uuid
.
uuid4
())
tag
=
TagModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"name"
:
name
})
try
:
with
get_session
()
as
db
:
result
=
Tag
(
**
tag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
TagModel
.
model_validate
(
result
)
else
:
...
...
@@ -99,8 +97,7 @@ class TagTable:
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
try
:
with
get_session
()
as
db
:
tag
=
db
.
query
(
Tag
).
filter
(
name
=
name
,
user_id
=
user_id
).
first
()
tag
=
Session
.
query
(
Tag
).
filter
(
name
=
name
,
user_id
=
user_id
).
first
()
return
TagModel
.
model_validate
(
tag
)
except
Exception
as
e
:
return
None
...
...
@@ -123,11 +120,10 @@ class TagTable:
}
)
try
:
with
get_session
()
as
db
:
result
=
ChatIdTag
(
**
chatIdTag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
ChatIdTagModel
.
model_validate
(
result
)
else
:
...
...
@@ -136,11 +132,10 @@ class TagTable:
return
None
def
get_tags_by_user_id
(
self
,
user_id
:
str
)
->
List
[
TagModel
]:
with
get_session
()
as
db
:
tag_names
=
[
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
Session
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
...
...
@@ -150,7 +145,7 @@ class TagTable:
return
[
TagModel
.
model_validate
(
tag
)
for
tag
in
(
db
.
query
(
Tag
)
Session
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
...
...
@@ -160,11 +155,10 @@ class TagTable:
def
get_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
)
->
List
[
TagModel
]:
with
get_session
()
as
db
:
tag_names
=
[
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
Session
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
chat_id
=
chat_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
...
...
@@ -174,7 +168,7 @@ class TagTable:
return
[
TagModel
.
model_validate
(
tag
)
for
tag
in
(
db
.
query
(
Tag
)
Session
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
...
...
@@ -184,11 +178,10 @@ class TagTable:
def
get_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
List
[
ChatIdTagModel
]:
with
get_session
()
as
db
:
return
[
ChatIdTagModel
.
model_validate
(
chat_id_tag
)
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
Session
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
tag_name
=
tag_name
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
...
...
@@ -198,30 +191,28 @@ class TagTable:
def
count_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
int
:
with
get_session
()
as
db
:
return
(
db
.
query
(
ChatIdTag
)
Session
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
count
()
)
def
delete_tag_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
res
=
(
db
.
query
(
ChatIdTag
)
Session
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
delete
()
)
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
Session
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
Session
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
return
True
except
Exception
as
e
:
log
.
error
(
f
"delete_tag:
{
e
}
"
)
...
...
@@ -231,21 +222,20 @@ class TagTable:
self
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
res
=
(
db
.
query
(
ChatIdTag
)
Session
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
chat_id
=
chat_id
,
user_id
=
user_id
)
.
delete
()
)
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
Session
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
Session
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
return
True
except
Exception
as
e
:
...
...
backend/apps/webui/models/tools.py
View file @
da403f3e
...
...
@@ -3,9 +3,8 @@ from typing import List, Optional
import
time
import
logging
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_s
ession
from
apps.webui.internal.db
import
Base
,
JSONField
,
S
ession
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -95,11 +94,10 @@ class ToolsTable:
)
try
:
with
get_session
()
as
db
:
result
=
Tool
(
**
tool
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
ToolModel
.
model_validate
(
result
)
else
:
...
...
@@ -110,20 +108,17 @@ class ToolsTable:
def
get_tool_by_id
(
self
,
id
:
str
)
->
Optional
[
ToolModel
]:
try
:
with
get_session
()
as
db
:
tool
=
db
.
get
(
Tool
,
id
)
tool
=
Session
.
get
(
Tool
,
id
)
return
ToolModel
.
model_validate
(
tool
)
except
:
return
None
def
get_tools
(
self
)
->
List
[
ToolModel
]:
with
get_session
()
as
db
:
return
[
ToolModel
.
model_validate
(
tool
)
for
tool
in
db
.
query
(
Tool
).
all
()]
return
[
ToolModel
.
model_validate
(
tool
)
for
tool
in
Session
.
query
(
Tool
).
all
()]
def
get_tool_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
with
get_session
()
as
db
:
tool
=
db
.
get
(
Tool
,
id
)
tool
=
Session
.
get
(
Tool
,
id
)
return
tool
.
valves
if
tool
.
valves
else
{}
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
...
...
@@ -131,11 +126,10 @@ class ToolsTable:
def
update_tool_valves_by_id
(
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
ToolValves
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
Session
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
"valves"
:
valves
,
"updated_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
Session
.
commit
()
return
self
.
get_tool_by_id
(
id
)
except
:
return
None
...
...
@@ -183,19 +177,18 @@ class ToolsTable:
def
update_tool_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
ToolModel
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
**
updated
,
"
updated_at
"
:
int
(
time
.
time
())
}
)
db
.
commit
(
)
return
self
.
get_tool_by_id
(
id
)
tool
=
Session
.
get
(
Tool
,
id
)
tool
.
update
(
**
update
d
)
tool
.
updated_at
=
int
(
time
.
time
())
Session
.
commit
(
)
Session
.
refresh
(
tool
)
return
ToolModel
.
model_validate
(
tool
)
except
:
return
None
def
delete_tool_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
delete
()
Session
.
query
(
Tool
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
return
False
...
...
backend/apps/webui/models/users.py
View file @
da403f3e
...
...
@@ -3,11 +3,10 @@ from typing import List, Union, Optional
import
time
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
sqlalchemy.orm
import
Session
from
utils.misc
import
get_gravatar_url
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_s
ession
from
apps.webui.internal.db
import
Base
,
JSONField
,
S
ession
from
apps.webui.models.chats
import
Chats
####################
...
...
@@ -89,7 +88,6 @@ class UsersTable:
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
user
=
UserModel
(
**
{
"id"
:
id
,
...
...
@@ -104,74 +102,66 @@ class UsersTable:
}
)
result
=
User
(
**
user
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
Session
.
add
(
result
)
Session
.
commit
()
Session
.
refresh
(
result
)
if
result
:
return
user
else
:
return
None
def
get_user_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
Exception
as
e
:
return
None
def
get_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
filter_by
(
api_key
=
api_key
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
api_key
=
api_key
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
filter_by
(
email
=
email
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
email
=
email
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_user_by_oauth_sub
(
self
,
sub
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
filter_by
(
oauth_sub
=
sub
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
oauth_sub
=
sub
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_users
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
with
get_session
()
as
db
:
users
=
(
db
.
query
(
User
)
Session
.
query
(
User
)
# .offset(skip).limit(limit)
.
all
()
)
return
[
UserModel
.
model_validate
(
user
)
for
user
in
users
]
def
get_num_users
(
self
)
->
Optional
[
int
]:
with
get_session
()
as
db
:
return
db
.
query
(
User
).
count
()
return
Session
.
query
(
User
).
count
()
def
get_first_user
(
self
)
->
UserModel
:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
order_by
(
User
.
created_at
).
first
()
user
=
Session
.
query
(
User
).
order_by
(
User
.
created_at
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
update_user_role_by_id
(
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"role"
:
role
})
db
.
commit
()
Session
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"role"
:
role
})
Session
.
commit
()
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
...
...
@@ -179,26 +169,24 @@ class UsersTable:
def
update_user_profile_image_url_by_id
(
self
,
id
:
str
,
profile_image_url
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
Session
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"profile_image_url"
:
profile_image_url
}
)
db
.
commit
()
Session
.
commit
()
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
update_user_last_active_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
Session
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"last_active_at"
:
int
(
time
.
time
())}
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
...
...
@@ -206,37 +194,34 @@ class UsersTable:
def
update_user_oauth_sub_by_id
(
self
,
id
:
str
,
oauth_sub
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"oauth_sub"
:
oauth_sub
})
Session
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"oauth_sub"
:
oauth_sub
})
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
updated
)
db
.
commit
()
Session
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
updated
)
Session
.
commit
()
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
# return UserModel(**user.dict())
except
Exception
as
e
:
return
None
def
delete_user_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
# Delete User Chats
result
=
Chats
.
delete_chats_by_user_id
(
id
)
if
result
:
# Delete User
db
.
query
(
User
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
Session
.
query
(
User
).
filter_by
(
id
=
id
).
delete
()
Session
.
commit
()
return
True
else
:
...
...
@@ -245,18 +230,16 @@ class UsersTable:
return
False
def
update_user_api_key_by_id
(
self
,
id
:
str
,
api_key
:
str
)
->
str
:
with
get_session
()
as
db
:
try
:
result
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"api_key"
:
api_key
})
db
.
commit
()
result
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"api_key"
:
api_key
})
Session
.
commit
()
return
True
if
result
==
1
else
False
except
:
return
False
def
get_user_api_key_by_id
(
self
,
id
:
str
)
->
Optional
[
str
]:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
user
=
Session
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
user
.
api_key
except
Exception
as
e
:
return
None
...
...
backend/main.py
View file @
da403f3e
...
...
@@ -29,7 +29,6 @@ from fastapi import HTTPException
from
fastapi.middleware.wsgi
import
WSGIMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
sqlalchemy
import
text
from
sqlalchemy.orm
import
Session
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
...
...
@@ -57,7 +56,7 @@ from apps.webui.main import (
get_pipe_models
,
generate_function_chat_completion
,
)
from
apps.webui.internal.db
import
get_s
ession
,
SessionLocal
from
apps.webui.internal.db
import
S
ession
,
SessionLocal
from
pydantic
import
BaseModel
...
...
@@ -794,6 +793,14 @@ app.add_middleware(
allow_headers
=
[
"*"
],
)
@
app
.
middleware
(
"http"
)
async
def
remove_session_after_request
(
request
:
Request
,
call_next
):
response
=
await
call_next
(
request
)
log
.
debug
(
"Removing session after request"
)
Session
.
commit
()
Session
.
remove
()
return
response
@
app
.
middleware
(
"http"
)
async
def
check_url
(
request
:
Request
,
call_next
):
...
...
@@ -2034,8 +2041,7 @@ async def healthcheck():
@
app
.
get
(
"/health/db"
)
async
def
healthcheck_with_db
():
with
get_session
()
as
db
:
result
=
db
.
execute
(
text
(
"SELECT 1;"
)).
all
()
Session
.
execute
(
text
(
"SELECT 1;"
)).
all
()
return
{
"status"
:
True
}
...
...
backend/test/apps/webui/routers/test_chats.py
View file @
da403f3e
...
...
@@ -90,6 +90,8 @@ class TestChats(AbstractPostgresTest):
def
test_get_user_archived_chats
(
self
):
self
.
chats
.
archive_all_chats_by_user_id
(
"2"
)
from
apps.webui.internal.db
import
Session
Session
.
commit
()
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/all/archived"
))
assert
response
.
status_code
==
200
...
...
backend/test/util/abstract_integration_test.py
View file @
da403f3e
...
...
@@ -9,6 +9,7 @@ from pytest_docker.plugin import get_docker_ip
from
fastapi.testclient
import
TestClient
from
sqlalchemy
import
text
,
create_engine
log
=
logging
.
getLogger
(
__name__
)
...
...
@@ -50,11 +51,6 @@ class AbstractPostgresTest(AbstractIntegrationTest):
DOCKER_CONTAINER_NAME
=
"postgres-test-container-will-get-deleted"
docker_client
:
DockerClient
def
get_db
(
self
):
from
apps.webui.internal.db
import
SessionLocal
return
SessionLocal
()
@
classmethod
def
_create_db_url
(
cls
,
env_vars_postgres
:
dict
)
->
str
:
host
=
get_docker_ip
()
...
...
@@ -113,21 +109,21 @@ class AbstractPostgresTest(AbstractIntegrationTest):
pytest
.
fail
(
f
"Could not setup test environment:
{
ex
}
"
)
def
_check_db_connection
(
self
):
from
apps.webui.internal.db
import
Session
retries
=
10
while
retries
>
0
:
try
:
self
.
db_s
ession
.
execute
(
text
(
"SELECT 1"
))
self
.
db_s
ession
.
commit
()
S
ession
.
execute
(
text
(
"SELECT 1"
))
S
ession
.
commit
()
break
except
Exception
as
e
:
self
.
db_s
ession
.
rollback
()
S
ession
.
rollback
()
log
.
warning
(
e
)
time
.
sleep
(
3
)
retries
-=
1
def
setup_method
(
self
):
super
().
setup_method
()
self
.
db_session
=
self
.
get_db
()
self
.
_check_db_connection
()
@
classmethod
...
...
@@ -136,8 +132,9 @@ class AbstractPostgresTest(AbstractIntegrationTest):
cls
.
docker_client
.
containers
.
get
(
cls
.
DOCKER_CONTAINER_NAME
).
remove
(
force
=
True
)
def
teardown_method
(
self
):
from
apps.webui.internal.db
import
Session
# rollback everything not yet committed
self
.
db_s
ession
.
commit
()
S
ession
.
commit
()
# truncate all tables
tables
=
[
...
...
@@ -152,5 +149,5 @@ class AbstractPostgresTest(AbstractIntegrationTest):
'"user"'
,
]
for
table
in
tables
:
self
.
db_s
ession
.
execute
(
text
(
f
"TRUNCATE TABLE
{
table
}
"
))
self
.
db_s
ession
.
commit
()
S
ession
.
execute
(
text
(
f
"TRUNCATE TABLE
{
table
}
"
))
S
ession
.
commit
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment