Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
bee835cb
Commit
bee835cb
authored
Jun 21, 2024
by
Jonathan Rohde
Browse files
feat(sqlalchemy): remove session reference from router
parent
df09d083
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
945 additions
and
874 deletions
+945
-874
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+2
-5
backend/apps/openai/main.py
backend/apps/openai/main.py
+1
-3
backend/apps/webui/internal/db.py
backend/apps/webui/internal/db.py
+5
-4
backend/apps/webui/main.py
backend/apps/webui/main.py
+2
-2
backend/apps/webui/models/auths.py
backend/apps/webui/models/auths.py
+77
-71
backend/apps/webui/models/chats.py
backend/apps/webui/models/chats.py
+179
-159
backend/apps/webui/models/documents.py
backend/apps/webui/models/documents.py
+46
-39
backend/apps/webui/models/files.py
backend/apps/webui/models/files.py
+26
-19
backend/apps/webui/models/functions.py
backend/apps/webui/models/functions.py
+62
-55
backend/apps/webui/models/memories.py
backend/apps/webui/models/memories.py
+35
-28
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+31
-26
backend/apps/webui/models/prompts.py
backend/apps/webui/models/prompts.py
+53
-49
backend/apps/webui/models/tags.py
backend/apps/webui/models/tags.py
+118
-107
backend/apps/webui/models/tools.py
backend/apps/webui/models/tools.py
+37
-33
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+152
-137
backend/apps/webui/routers/auths.py
backend/apps/webui/routers/auths.py
+26
-34
backend/apps/webui/routers/chats.py
backend/apps/webui/routers/chats.py
+56
-60
backend/apps/webui/routers/documents.py
backend/apps/webui/routers/documents.py
+12
-14
backend/apps/webui/routers/files.py
backend/apps/webui/routers/files.py
+12
-15
backend/apps/webui/routers/functions.py
backend/apps/webui/routers/functions.py
+13
-14
No files found.
backend/apps/ollama/main.py
View file @
bee835cb
...
...
@@ -31,7 +31,6 @@ from typing import Optional, List, Union
from
starlette.background
import
BackgroundTask
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.models
import
Models
from
apps.webui.models.users
import
Users
from
constants
import
ERROR_MESSAGES
...
...
@@ -712,7 +711,6 @@ async def generate_chat_completion(
form_data
:
GenerateChatCompletionForm
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
db
=
Depends
(
get_db
),
):
log
.
debug
(
...
...
@@ -726,7 +724,7 @@ async def generate_chat_completion(
}
model_id
=
form_data
.
model
model_info
=
Models
.
get_model_by_id
(
db
,
model_id
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
if
model_info
:
if
model_info
.
base_model_id
:
...
...
@@ -885,7 +883,6 @@ async def generate_openai_chat_completion(
form_data
:
dict
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
db
=
Depends
(
get_db
),
):
form_data
=
OpenAIChatCompletionForm
(
**
form_data
)
...
...
@@ -894,7 +891,7 @@ async def generate_openai_chat_completion(
}
model_id
=
form_data
.
model
model_info
=
Models
.
get_model_by_id
(
db
,
model_id
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
if
model_info
:
if
model_info
.
base_model_id
:
...
...
backend/apps/openai/main.py
View file @
bee835cb
...
...
@@ -11,7 +11,6 @@ import logging
from
pydantic
import
BaseModel
from
starlette.background
import
BackgroundTask
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.models
import
Models
from
apps.webui.models.users
import
Users
from
constants
import
ERROR_MESSAGES
...
...
@@ -354,13 +353,12 @@ async def generate_chat_completion(
form_data
:
dict
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
db
=
Depends
(
get_db
),
):
idx
=
0
payload
=
{
**
form_data
}
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
db
,
model_id
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
if
model_info
:
if
model_info
.
base_model_id
:
...
...
backend/apps/webui/internal/db.py
View file @
bee835cb
import
os
import
logging
import
json
from
contextlib
import
contextmanager
from
typing
import
Optional
,
Any
from
typing_extensions
import
Self
...
...
@@ -52,11 +53,12 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
)
else
:
engine
=
create_engine
(
SQLALCHEMY_DATABASE_URL
,
pool_pre_ping
=
True
)
SessionLocal
=
sessionmaker
(
autocommit
=
False
,
autoflush
=
False
,
bind
=
engine
)
SessionLocal
=
sessionmaker
(
autocommit
=
False
,
autoflush
=
False
,
bind
=
engine
,
expire_on_commit
=
False
)
Base
=
declarative_base
()
def
get_db
():
@
contextmanager
def
get_session
():
db
=
SessionLocal
()
try
:
yield
db
...
...
@@ -64,5 +66,4 @@ def get_db():
except
Exception
as
e
:
db
.
rollback
()
raise
e
finally
:
db
.
close
()
backend/apps/webui/main.py
View file @
bee835cb
...
...
@@ -114,8 +114,8 @@ async def get_status():
}
async
def
get_pipe_models
(
db
:
Session
):
pipes
=
Functions
.
get_functions_by_type
(
db
,
"pipe"
,
active_only
=
True
)
async
def
get_pipe_models
():
pipes
=
Functions
.
get_functions_by_type
(
"pipe"
,
active_only
=
True
)
pipe_models
=
[]
for
pipe
in
pipes
:
...
...
backend/apps/webui/models/auths.py
View file @
bee835cb
...
...
@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from
apps.webui.models.users
import
UserModel
,
Users
from
utils.utils
import
verify_password
from
apps.webui.internal.db
import
Base
from
apps.webui.internal.db
import
Base
,
get_session
from
config
import
SRC_LOG_LEVELS
...
...
@@ -96,7 +96,6 @@ class AuthsTable:
def
insert_new_auth
(
self
,
db
:
Session
,
email
:
str
,
password
:
str
,
name
:
str
,
...
...
@@ -104,6 +103,7 @@ class AuthsTable:
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
log
.
info
(
"insert_new_auth"
)
id
=
str
(
uuid
.
uuid4
())
...
...
@@ -115,7 +115,7 @@ class AuthsTable:
db
.
add
(
result
)
user
=
Users
.
insert_new_user
(
db
,
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
)
db
.
commit
()
...
...
@@ -127,14 +127,15 @@ class AuthsTable:
return
None
def
authenticate_user
(
self
,
db
:
Session
,
email
:
str
,
password
:
str
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user:
{
email
}
"
)
with
get_session
()
as
db
:
try
:
auth
=
db
.
query
(
Auth
).
filter_by
(
email
=
email
,
active
=
True
).
first
()
if
auth
:
if
verify_password
(
password
,
auth
.
password
):
user
=
Users
.
get_user_by_id
(
db
,
auth
.
id
)
user
=
Users
.
get_user_by_id
(
auth
.
id
)
return
user
else
:
return
None
...
...
@@ -144,23 +145,25 @@ class AuthsTable:
return
None
def
authenticate_user_by_api_key
(
self
,
db
:
Session
,
api_key
:
str
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user_by_api_key:
{
api_key
}
"
)
with
get_session
()
as
db
:
# if no api_key, return None
if
not
api_key
:
return
None
try
:
user
=
Users
.
get_user_by_api_key
(
db
,
api_key
)
user
=
Users
.
get_user_by_api_key
(
api_key
)
return
user
if
user
else
None
except
:
return
False
def
authenticate_user_by_trusted_header
(
self
,
db
:
Session
,
email
:
str
self
,
email
:
str
)
->
Optional
[
UserModel
]:
log
.
info
(
f
"authenticate_user_by_trusted_header:
{
email
}
"
)
with
get_session
()
as
db
:
try
:
auth
=
db
.
query
(
Auth
).
filter
(
email
=
email
,
active
=
True
).
first
()
if
auth
:
...
...
@@ -170,25 +173,28 @@ class AuthsTable:
return
None
def
update_user_password_by_id
(
self
,
db
:
Session
,
id
:
str
,
new_password
:
str
self
,
id
:
str
,
new_password
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
result
=
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"password"
:
new_password
})
return
True
if
result
==
1
else
False
except
:
return
False
def
update_email_by_id
(
self
,
db
:
Session
,
id
:
str
,
email
:
str
)
->
bool
:
def
update_email_by_id
(
self
,
id
:
str
,
email
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
result
=
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
update
({
"email"
:
email
})
return
True
if
result
==
1
else
False
except
:
return
False
def
delete_auth_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
def
delete_auth_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
# Delete User
result
=
Users
.
delete_user_by_id
(
db
,
id
)
result
=
Users
.
delete_user_by_id
(
id
)
if
result
:
db
.
query
(
Auth
).
filter_by
(
id
=
id
).
delete
()
...
...
backend/apps/webui/models/chats.py
View file @
bee835cb
...
...
@@ -8,7 +8,7 @@ import time
from
sqlalchemy
import
Column
,
String
,
BigInteger
,
Boolean
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
from
apps.webui.internal.db
import
Base
,
get_session
####################
...
...
@@ -80,8 +80,9 @@ class ChatTitleIdResponse(BaseModel):
class
ChatTable
:
def
insert_new_chat
(
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
ChatForm
self
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
with
get_session
()
as
db
:
id
=
str
(
uuid
.
uuid4
())
chat
=
ChatModel
(
**
{
...
...
@@ -103,29 +104,30 @@ class ChatTable:
return
ChatModel
.
model_validate
(
result
)
if
result
else
None
def
update_chat_by_id
(
self
,
db
:
Session
,
id
:
str
,
chat
:
dict
self
,
id
:
str
,
chat
:
dict
)
->
Optional
[
ChatModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
update
(
{
"chat"
:
json
.
dumps
(
chat
),
"title"
:
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
,
"updated_at"
:
int
(
time
.
time
()),
}
)
chat_obj
=
db
.
get
(
Chat
,
id
)
chat_obj
.
chat
=
json
.
dumps
(
chat
)
chat_obj
.
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
chat_obj
.
updated_at
=
int
(
time
.
time
())
db
.
commit
()
db
.
refresh
(
chat_obj
)
return
self
.
get_chat_by_id
(
db
,
id
)
except
:
return
ChatModel
.
model_validate
(
chat_obj
)
except
Exception
as
e
:
return
None
def
insert_shared_chat_by_chat_id
(
self
,
db
:
Session
,
chat_id
:
str
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
with
get_session
()
as
db
:
# Get the existing chat to share
chat
=
db
.
get
(
Chat
,
chat_id
)
# Check if the chat is already shared
if
chat
.
share_id
:
return
self
.
get_chat_by_id_and_user_id
(
db
,
chat
.
share_id
,
"shared"
)
return
self
.
get_chat_by_id_and_user_id
(
chat
.
share_id
,
"shared"
)
# Create a new chat with the same data, but with a new ID
shared_chat
=
ChatModel
(
**
{
...
...
@@ -149,49 +151,56 @@ class ChatTable:
return
shared_chat
if
(
shared_result
and
result
)
else
None
def
update_shared_chat_by_chat_id
(
self
,
db
:
Session
,
chat_id
:
str
self
,
chat_id
:
str
)
->
Optional
[
ChatModel
]:
with
get_session
()
as
db
:
try
:
print
(
"update_shared_chat_by_id"
)
chat
=
db
.
get
(
Chat
,
chat_id
)
print
(
chat
)
chat
.
title
=
chat
.
title
chat
.
chat
=
chat
.
chat
db
.
commit
()
db
.
refresh
(
chat
)
db
.
query
(
Chat
).
filter_by
(
id
=
chat
.
share_id
).
update
(
{
"title"
:
chat
.
title
,
"chat"
:
chat
.
chat
}
)
return
self
.
get_chat_by_id
(
db
,
chat
.
share_id
)
return
self
.
get_chat_by_id
(
chat
.
share_id
)
except
:
return
None
def
delete_shared_chat_by_chat_id
(
self
,
db
:
Session
,
chat_id
:
str
)
->
bool
:
def
delete_shared_chat_by_chat_id
(
self
,
chat_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
user_id
=
f
"shared-
{
chat_id
}
"
).
delete
()
return
True
except
:
return
False
def
update_chat_share_id_by_id
(
self
,
db
:
Session
,
id
:
str
,
share_id
:
Optional
[
str
]
self
,
id
:
str
,
share_id
:
Optional
[
str
]
)
->
Optional
[
ChatModel
]:
try
:
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
update
({
"share_id"
:
share_id
})
return
self
.
get_chat_by_id
(
db
,
id
)
with
get_session
()
as
db
:
chat
=
db
.
get
(
Chat
,
id
)
chat
.
share_id
=
share_id
db
.
commit
()
db
.
refresh
(
chat
)
return
chat
except
:
return
None
def
toggle_chat_archive_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ChatModel
]:
def
toggle_chat_archive_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
chat
=
self
.
get_chat_by_id
(
db
,
id
)
with
get_session
()
as
db
:
chat
=
self
.
get_chat_by_id
(
id
)
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
update
({
"archived"
:
not
chat
.
archived
})
return
self
.
get_chat_by_id
(
db
,
id
)
return
self
.
get_chat_by_id
(
id
)
except
:
return
None
def
archive_all_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
bool
:
def
archive_all_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
update
({
"archived"
:
True
})
return
True
...
...
@@ -199,8 +208,9 @@ class ChatTable:
return
False
def
get_archived_chat_list_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
...
...
@@ -212,12 +222,12 @@ class ChatTable:
def
get_chat_list_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
,
include_archived
:
bool
=
False
,
skip
:
int
=
0
,
limit
:
int
=
50
,
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
query
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
)
if
not
include_archived
:
query
=
query
.
filter_by
(
archived
=
False
)
...
...
@@ -229,8 +239,9 @@ class ChatTable:
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_list_by_chat_ids
(
self
,
db
:
Session
,
chat_ids
:
List
[
str
],
skip
:
int
=
0
,
limit
:
int
=
50
self
,
chat_ids
:
List
[
str
],
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
.
filter
(
Chat
.
id
.
in_
(
chat_ids
))
...
...
@@ -240,34 +251,38 @@ class ChatTable:
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chat_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ChatModel
]:
def
get_chat_by_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
with
get_session
()
as
db
:
chat
=
db
.
get
(
Chat
,
id
)
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
get_chat_by_share_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ChatModel
]:
def
get_chat_by_share_id
(
self
,
id
:
str
)
->
Optional
[
ChatModel
]:
try
:
with
get_session
()
as
db
:
chat
=
db
.
query
(
Chat
).
filter_by
(
share_id
=
id
).
first
()
if
chat
:
return
self
.
get_chat_by_id
(
db
,
id
)
return
self
.
get_chat_by_id
(
id
)
else
:
return
None
except
Exception
as
e
:
return
None
def
get_chat_by_id_and_user_id
(
self
,
db
:
Session
,
id
:
str
,
user_id
:
str
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
with
get_session
()
as
db
:
chat
=
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
first
()
return
ChatModel
.
model_validate
(
chat
)
except
:
return
None
def
get_chats
(
self
,
db
:
Session
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
def
get_chats
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
# .limit(limit).offset(skip)
...
...
@@ -275,15 +290,17 @@ class ChatTable:
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
List
[
ChatModel
]:
def
get_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
order_by
(
Chat
.
updated_at
.
desc
())
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
get_archived_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
with
get_session
()
as
db
:
all_chats
=
(
db
.
query
(
Chat
)
.
filter_by
(
user_id
=
user_id
,
archived
=
True
)
...
...
@@ -291,34 +308,37 @@ class ChatTable:
)
return
[
ChatModel
.
model_validate
(
chat
)
for
chat
in
all_chats
]
def
delete_chat_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
def
delete_chat_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
id
=
id
).
delete
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
db
,
id
)
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
return
False
def
delete_chat_by_id_and_user_id
(
self
,
db
:
Session
,
id
:
str
,
user_id
:
str
)
->
bool
:
def
delete_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Chat
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
return
True
and
self
.
delete_shared_chat_by_chat_id
(
db
,
id
)
return
True
and
self
.
delete_shared_chat_by_chat_id
(
id
)
except
:
return
False
def
delete_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
bool
:
def
delete_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
self
.
delete_shared_chats_by_user_id
(
db
,
user_id
)
with
get_session
()
as
db
:
self
.
delete_shared_chats_by_user_id
(
user_id
)
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
delete
()
return
True
except
:
return
False
def
delete_shared_chats_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
bool
:
def
delete_shared_chats_by_user_id
(
self
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
chats_by_user
=
db
.
query
(
Chat
).
filter_by
(
user_id
=
user_id
).
all
()
shared_chat_ids
=
[
f
"shared-
{
chat
.
id
}
"
for
chat
in
chats_by_user
]
...
...
backend/apps/webui/models/documents.py
View file @
bee835cb
...
...
@@ -6,7 +6,7 @@ import logging
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
from
apps.webui.internal.db
import
Base
,
get_session
import
json
...
...
@@ -73,7 +73,7 @@ class DocumentForm(DocumentUpdateForm):
class
DocumentsTable
:
def
insert_new_doc
(
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
DocumentForm
self
,
user_id
:
str
,
form_data
:
DocumentForm
)
->
Optional
[
DocumentModel
]:
document
=
DocumentModel
(
**
{
...
...
@@ -84,6 +84,7 @@ class DocumentsTable:
)
try
:
with
get_session
()
as
db
:
result
=
Document
(
**
document
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
...
...
@@ -95,20 +96,23 @@ class DocumentsTable:
except
:
return
None
def
get_doc_by_name
(
self
,
db
:
Session
,
name
:
str
)
->
Optional
[
DocumentModel
]:
def
get_doc_by_name
(
self
,
name
:
str
)
->
Optional
[
DocumentModel
]:
try
:
with
get_session
()
as
db
:
document
=
db
.
query
(
Document
).
filter_by
(
name
=
name
).
first
()
return
DocumentModel
.
model_validate
(
document
)
if
document
else
None
except
:
return
None
def
get_docs
(
self
,
db
:
Session
)
->
List
[
DocumentModel
]:
def
get_docs
(
self
)
->
List
[
DocumentModel
]:
with
get_session
()
as
db
:
return
[
DocumentModel
.
model_validate
(
doc
)
for
doc
in
db
.
query
(
Document
).
all
()]
def
update_doc_by_name
(
self
,
db
:
Session
,
name
:
str
,
form_data
:
DocumentUpdateForm
self
,
name
:
str
,
form_data
:
DocumentUpdateForm
)
->
Optional
[
DocumentModel
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Document
).
filter_by
(
name
=
name
).
update
(
{
"title"
:
form_data
.
title
,
...
...
@@ -116,16 +120,18 @@ class DocumentsTable:
"timestamp"
:
int
(
time
.
time
()),
}
)
return
self
.
get_doc_by_name
(
db
,
form_data
.
name
)
db
.
commit
()
return
self
.
get_doc_by_name
(
form_data
.
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
return
None
def
update_doc_content_by_name
(
self
,
db
:
Session
,
name
:
str
,
updated
:
dict
self
,
name
:
str
,
updated
:
dict
)
->
Optional
[
DocumentModel
]:
try
:
doc
=
self
.
get_doc_by_name
(
db
,
name
)
with
get_session
()
as
db
:
doc
=
self
.
get_doc_by_name
(
name
)
doc_content
=
json
.
loads
(
doc
.
content
if
doc
.
content
else
"{}"
)
doc_content
=
{
**
doc_content
,
**
updated
}
...
...
@@ -135,14 +141,15 @@ class DocumentsTable:
"timestamp"
:
int
(
time
.
time
()),
}
)
return
self
.
get_doc_by_name
(
db
,
name
)
db
.
commit
()
return
self
.
get_doc_by_name
(
name
)
except
Exception
as
e
:
log
.
exception
(
e
)
return
None
def
delete_doc_by_name
(
self
,
db
:
Session
,
name
:
str
)
->
bool
:
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Document
).
filter_by
(
name
=
name
).
delete
()
return
True
except
:
...
...
backend/apps/webui/models/files.py
View file @
bee835cb
...
...
@@ -6,7 +6,7 @@ import logging
from
sqlalchemy
import
Column
,
String
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
JSONField
,
Base
from
apps.webui.internal.db
import
JSONField
,
Base
,
get_session
import
json
...
...
@@ -60,7 +60,7 @@ class FileForm(BaseModel):
class
FilesTable
:
def
insert_new_file
(
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
FileForm
)
->
Optional
[
FileModel
]:
def
insert_new_file
(
self
,
user_id
:
str
,
form_data
:
FileForm
)
->
Optional
[
FileModel
]:
file
=
FileModel
(
**
{
**
form_data
.
model_dump
(),
...
...
@@ -70,6 +70,7 @@ class FilesTable:
)
try
:
with
get_session
()
as
db
:
result
=
File
(
**
file
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
...
...
@@ -82,26 +83,32 @@ class FilesTable:
print
(
f
"Error creating tool:
{
e
}
"
)
return
None
def
get_file_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
FileModel
]:
def
get_file_by_id
(
self
,
id
:
str
)
->
Optional
[
FileModel
]:
try
:
with
get_session
()
as
db
:
file
=
db
.
get
(
File
,
id
)
return
FileModel
.
model_validate
(
file
)
except
:
return
None
def
get_files
(
self
,
db
:
Session
)
->
List
[
FileModel
]:
def
get_files
(
self
)
->
List
[
FileModel
]:
with
get_session
()
as
db
:
return
[
FileModel
.
model_validate
(
file
)
for
file
in
db
.
query
(
File
).
all
()]
def
delete_file_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
def
delete_file_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
File
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
except
:
return
False
def
delete_all_files
(
self
,
db
:
Session
)
->
bool
:
def
delete_all_files
(
self
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
File
).
delete
()
db
.
commit
()
return
True
except
:
return
False
...
...
backend/apps/webui/models/functions.py
View file @
bee835cb
...
...
@@ -6,7 +6,7 @@ import logging
from
sqlalchemy
import
Column
,
String
,
Text
,
BigInteger
,
Boolean
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
JSONField
,
Base
from
apps.webui.internal.db
import
JSONField
,
Base
,
get_session
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -87,7 +87,7 @@ class FunctionValves(BaseModel):
class
FunctionsTable
:
def
insert_new_function
(
self
,
db
:
Session
,
user_id
:
str
,
type
:
str
,
form_data
:
FunctionForm
self
,
user_id
:
str
,
type
:
str
,
form_data
:
FunctionForm
)
->
Optional
[
FunctionModel
]:
function
=
FunctionModel
(
**
{
...
...
@@ -100,6 +100,7 @@ class FunctionsTable:
)
try
:
with
get_session
()
as
db
:
result
=
Function
(
**
function
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
...
...
@@ -112,8 +113,9 @@ class FunctionsTable:
print
(
f
"Error creating tool:
{
e
}
"
)
return
None
def
get_function_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
FunctionModel
]:
def
get_function_by_id
(
self
,
id
:
str
)
->
Optional
[
FunctionModel
]:
try
:
with
get_session
()
as
db
:
function
=
db
.
get
(
Function
,
id
)
return
FunctionModel
.
model_validate
(
function
)
except
:
...
...
@@ -121,35 +123,40 @@ class FunctionsTable:
def
get_functions
(
self
,
active_only
=
False
)
->
List
[
FunctionModel
]:
if
active_only
:
with
get_session
()
as
db
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
().
where
(
Function
.
is_active
==
True
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
db
.
query
(
Function
).
filter_by
(
is_active
=
True
)
.
all
()
]
else
:
with
get_session
()
as
db
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
()
FunctionModel
.
model_
validate
(
function
)
for
function
in
db
.
query
(
Function
).
all
()
]
def
get_functions_by_type
(
self
,
type
:
str
,
active_only
=
False
)
->
List
[
FunctionModel
]:
if
active_only
:
with
get_session
()
as
db
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
().
where
(
Function
.
type
==
type
,
Function
.
is_active
==
True
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
db
.
query
(
Function
).
filter_by
(
type
=
type
,
is_active
=
True
).
all
(
)
]
else
:
with
get_session
()
as
db
:
return
[
FunctionModel
(
**
model_
to_dict
(
function
)
)
for
function
in
Function
.
select
().
where
(
Function
.
type
==
type
)
FunctionModel
.
model_
validate
(
function
)
for
function
in
db
.
query
(
Function
).
filter_by
(
type
=
type
).
all
(
)
]
def
get_function_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
function
=
Function
.
get
(
Function
.
id
==
id
)
with
get_session
()
as
db
:
function
=
db
.
get
(
Function
,
id
)
return
function
.
valves
if
function
.
valves
else
{}
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
...
...
@@ -159,14 +166,12 @@ class FunctionsTable:
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
FunctionValves
]:
try
:
query
=
Function
.
update
(
**
{
"valves"
:
valves
},
updated_at
=
int
(
time
.
time
()),
).
where
(
Function
.
id
==
id
)
query
.
execute
()
function
=
Function
.
get
(
Function
.
id
==
id
)
return
FunctionValves
(
**
model_to_dict
(
function
))
with
get_session
()
as
db
:
db
.
query
(
Function
).
filter_by
(
id
=
id
).
update
(
{
"valves"
:
valves
,
"updated_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
return
self
.
get_function_by_id
(
id
)
except
:
return
None
...
...
@@ -214,29 +219,31 @@ class FunctionsTable:
def
update_function_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
FunctionModel
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Function
).
filter_by
(
id
=
id
).
update
({
**
updated
,
"updated_at"
:
int
(
time
.
time
()),
})
return
self
.
get_function_by_id
(
db
,
id
)
db
.
commit
()
return
self
.
get_function_by_id
(
id
)
except
:
return
None
def
deactivate_all_functions
(
self
)
->
Optional
[
bool
]:
try
:
query
=
Function
.
update
(
**
{
"is_active"
:
False
},
updated_at
=
int
(
time
.
time
()),
)
query
.
execute
()
with
get_session
()
as
db
:
db
.
query
(
Function
).
update
({
"is_active"
:
False
,
"updated_at"
:
int
(
time
.
time
()),
})
db
.
commit
()
return
True
except
:
return
None
def
delete_function_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
def
delete_function_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Function
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
...
...
backend/apps/webui/models/memories.py
View file @
bee835cb
...
...
@@ -4,7 +4,7 @@ from typing import List, Union, Optional
from
sqlalchemy
import
Column
,
String
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
from
apps.webui.internal.db
import
Base
,
get_session
from
apps.webui.models.chats
import
Chats
import
time
...
...
@@ -44,7 +44,6 @@ class MemoriesTable:
def
insert_new_memory
(
self
,
db
:
Session
,
user_id
:
str
,
content
:
str
,
)
->
Optional
[
MemoryModel
]:
...
...
@@ -59,7 +58,8 @@ class MemoriesTable:
"updated_at"
:
int
(
time
.
time
()),
}
)
result
=
Memory
(
**
memory
.
dict
())
with
get_session
()
as
db
:
result
=
Memory
(
**
memory
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
...
...
@@ -70,41 +70,46 @@ class MemoriesTable:
def
update_memory_by_id
(
self
,
db
:
Session
,
id
:
str
,
content
:
str
,
)
->
Optional
[
MemoryModel
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
update
(
{
"content"
:
content
,
"updated_at"
:
int
(
time
.
time
())}
)
return
self
.
get_memory_by_id
(
db
,
id
)
db
.
commit
()
return
self
.
get_memory_by_id
(
id
)
except
:
return
None
def
get_memories
(
self
,
db
:
Session
)
->
List
[
MemoryModel
]:
def
get_memories
(
self
)
->
List
[
MemoryModel
]:
try
:
with
get_session
()
as
db
:
memories
=
db
.
query
(
Memory
).
all
()
return
[
MemoryModel
.
model_validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memories_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
List
[
MemoryModel
]:
def
get_memories_by_user_id
(
self
,
user_id
:
str
)
->
List
[
MemoryModel
]:
try
:
with
get_session
()
as
db
:
memories
=
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
).
all
()
return
[
MemoryModel
.
model_validate
(
memory
)
for
memory
in
memories
]
except
:
return
None
def
get_memory_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
MemoryModel
]:
def
get_memory_by_id
(
self
,
id
:
str
)
->
Optional
[
MemoryModel
]:
try
:
with
get_session
()
as
db
:
memory
=
db
.
get
(
Memory
,
id
)
return
MemoryModel
.
model_validate
(
memory
)
except
:
return
None
def
delete_memory_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
def
delete_memory_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Memory
).
filter_by
(
id
=
id
).
delete
()
return
True
...
...
@@ -113,6 +118,7 @@ class MemoriesTable:
def
delete_memories_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Memory
).
filter_by
(
user_id
=
user_id
).
delete
()
return
True
except
:
...
...
@@ -122,6 +128,7 @@ class MemoriesTable:
self
,
db
:
Session
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Memory
).
filter_by
(
id
=
id
,
user_id
=
user_id
).
delete
()
return
True
except
:
...
...
backend/apps/webui/models/models.py
View file @
bee835cb
...
...
@@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
JSONField
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_session
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
...
...
@@ -78,8 +78,6 @@ class Model(Base):
class
ModelModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
id
:
str
user_id
:
str
base_model_id
:
Optional
[
str
]
=
None
...
...
@@ -91,6 +89,8 @@ class ModelModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -116,7 +116,7 @@ class ModelForm(BaseModel):
class
ModelsTable
:
def
insert_new_model
(
self
,
db
:
Session
,
form_data
:
ModelForm
,
user_id
:
str
self
,
form_data
:
ModelForm
,
user_id
:
str
)
->
Optional
[
ModelModel
]:
model
=
ModelModel
(
**
{
...
...
@@ -127,7 +127,8 @@ class ModelsTable:
}
)
try
:
result
=
Model
(
**
model
.
dict
())
with
get_session
()
as
db
:
result
=
Model
(
**
model
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
...
...
@@ -140,21 +141,24 @@ class ModelsTable:
print
(
e
)
return
None
def
get_all_models
(
self
,
db
:
Session
)
->
List
[
ModelModel
]:
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
with
get_session
()
as
db
:
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ModelModel
]:
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
with
get_session
()
as
db
:
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
.
model_validate
(
model
)
except
:
return
None
def
update_model_by_id
(
self
,
db
:
Session
,
id
:
str
,
model
:
ModelForm
self
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
try
:
# update only the fields that are present in the model
with
get_session
()
as
db
:
model
=
db
.
query
(
Model
).
get
(
id
)
model
.
update
(
**
model
.
model_dump
())
db
.
commit
()
...
...
@@ -165,8 +169,9 @@ class ModelsTable:
return
None
def
delete_model_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
...
...
backend/apps/webui/models/prompts.py
View file @
bee835cb
...
...
@@ -5,7 +5,7 @@ import time
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
from
apps.webui.internal.db
import
Base
,
get_session
import
json
...
...
@@ -48,8 +48,9 @@ class PromptForm(BaseModel):
class
PromptsTable
:
def
insert_new_prompt
(
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
PromptForm
self
,
user_id
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
with
get_session
()
as
db
:
prompt
=
PromptModel
(
**
{
"user_id"
:
user_id
,
...
...
@@ -72,32 +73,35 @@ class PromptsTable:
except
Exception
as
e
:
return
None
def
get_prompt_by_command
(
self
,
db
:
Session
,
command
:
str
)
->
Optional
[
PromptModel
]:
def
get_prompt_by_command
(
self
,
command
:
str
)
->
Optional
[
PromptModel
]:
with
get_session
()
as
db
:
try
:
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
return
PromptModel
.
model_validate
(
prompt
)
except
:
return
None
def
get_prompts
(
self
,
db
:
Session
)
->
List
[
PromptModel
]:
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
with
get_session
()
as
db
:
return
[
PromptModel
.
model_validate
(
prompt
)
for
prompt
in
db
.
query
(
Prompt
).
all
()]
def
update_prompt_by_command
(
self
,
db
:
Session
,
command
:
str
,
form_data
:
PromptForm
self
,
command
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
update
(
{
"title"
:
form_data
.
title
,
"content"
:
form_data
.
content
,
"timestamp"
:
int
(
time
.
time
()),
}
)
return
self
.
get_prompt_by_command
(
db
,
command
)
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
prompt
.
title
=
form_data
.
title
prompt
.
content
=
form_data
.
content
prompt
.
timestamp
=
int
(
time
.
time
())
db
.
commit
()
return
prompt
# return self.get_prompt_by_command(command)
except
:
return
None
def
delete_prompt_by_command
(
self
,
db
:
Session
,
command
:
str
)
->
bool
:
def
delete_prompt_by_command
(
self
,
command
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
delete
()
return
True
...
...
backend/apps/webui/models/tags.py
View file @
bee835cb
...
...
@@ -9,7 +9,7 @@ import logging
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
from
apps.webui.internal.db
import
Base
,
get_session
from
config
import
SRC_LOG_LEVELS
...
...
@@ -80,12 +80,13 @@ class ChatTagsResponse(BaseModel):
class
TagTable
:
def
insert_new_tag
(
self
,
db
:
Session
,
name
:
str
,
user_id
:
str
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
id
=
str
(
uuid
.
uuid4
())
tag
=
TagModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"name"
:
name
})
try
:
result
=
Tag
(
**
tag
.
dict
())
with
get_session
()
as
db
:
result
=
Tag
(
**
tag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
...
...
@@ -97,20 +98,21 @@ class TagTable:
return
None
def
get_tag_by_name_and_user_id
(
self
,
db
:
Session
,
name
:
str
,
user_id
:
str
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
try
:
with
get_session
()
as
db
:
tag
=
db
.
query
(
Tag
).
filter
(
name
=
name
,
user_id
=
user_id
).
first
()
return
TagModel
.
model_validate
(
tag
)
except
Exception
as
e
:
return
None
def
add_tag_to_chat
(
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
ChatIdTagForm
self
,
user_id
:
str
,
form_data
:
ChatIdTagForm
)
->
Optional
[
ChatIdTagModel
]:
tag
=
self
.
get_tag_by_name_and_user_id
(
db
,
form_data
.
tag_name
,
user_id
)
tag
=
self
.
get_tag_by_name_and_user_id
(
form_data
.
tag_name
,
user_id
)
if
tag
==
None
:
tag
=
self
.
insert_new_tag
(
db
,
form_data
.
tag_name
,
user_id
)
tag
=
self
.
insert_new_tag
(
form_data
.
tag_name
,
user_id
)
id
=
str
(
uuid
.
uuid4
())
chatIdTag
=
ChatIdTagModel
(
...
...
@@ -123,7 +125,8 @@ class TagTable:
}
)
try
:
result
=
ChatIdTag
(
**
chatIdTag
.
dict
())
with
get_session
()
as
db
:
result
=
ChatIdTag
(
**
chatIdTag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
...
...
@@ -134,7 +137,8 @@ class TagTable:
except
:
return
None
def
get_tags_by_user_id
(
self
,
db
:
Session
,
user_id
:
str
)
->
List
[
TagModel
]:
def
get_tags_by_user_id
(
self
,
user_id
:
str
)
->
List
[
TagModel
]:
with
get_session
()
as
db
:
tag_names
=
[
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
...
...
@@ -156,8 +160,9 @@ class TagTable:
]
def
get_tags_by_chat_id_and_user_id
(
self
,
db
:
Session
,
chat_id
:
str
,
user_id
:
str
self
,
chat_id
:
str
,
user_id
:
str
)
->
List
[
TagModel
]:
with
get_session
()
as
db
:
tag_names
=
[
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
...
...
@@ -179,8 +184,9 @@ class TagTable:
]
def
get_chat_ids_by_tag_name_and_user_id
(
self
,
db
:
Session
,
tag_name
:
str
,
user_id
:
str
self
,
tag_name
:
str
,
user_id
:
str
)
->
List
[
ChatIdTagModel
]:
with
get_session
()
as
db
:
return
[
ChatIdTagModel
.
model_validate
(
chat_id_tag
)
for
chat_id_tag
in
(
...
...
@@ -192,23 +198,26 @@ class TagTable:
]
def
count_chat_ids_by_tag_name_and_user_id
(
self
,
db
:
Session
,
tag_name
:
str
,
user_id
:
str
self
,
tag_name
:
str
,
user_id
:
str
)
->
int
:
with
get_session
()
as
db
:
return
db
.
query
(
ChatIdTag
).
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
).
count
()
def
delete_tag_by_tag_name_and_user_id
(
self
,
db
:
Session
,
tag_name
:
str
,
user_id
:
str
self
,
tag_name
:
str
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
res
=
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
delete
()
)
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
db
,
tag_name
,
user_id
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
...
...
@@ -219,18 +228,20 @@ class TagTable:
return
False
def
delete_tag_by_tag_name_and_chat_id_and_user_id
(
self
,
db
:
Session
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
self
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
res
=
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
chat_id
=
chat_id
,
user_id
=
user_id
)
.
delete
()
)
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
db
,
tag_name
,
user_id
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
...
...
@@ -242,13 +253,13 @@ class TagTable:
return
False
def
delete_tags_by_chat_id_and_user_id
(
self
,
db
:
Session
,
chat_id
:
str
,
user_id
:
str
self
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
tags
=
self
.
get_tags_by_chat_id_and_user_id
(
db
,
chat_id
,
user_id
)
tags
=
self
.
get_tags_by_chat_id_and_user_id
(
chat_id
,
user_id
)
for
tag
in
tags
:
self
.
delete_tag_by_tag_name_and_chat_id_and_user_id
(
db
,
tag
.
tag_name
,
chat_id
,
user_id
tag
.
tag_name
,
chat_id
,
user_id
)
return
True
...
...
backend/apps/webui/models/tools.py
View file @
bee835cb
...
...
@@ -5,7 +5,7 @@ import logging
from
sqlalchemy
import
String
,
Column
,
BigInteger
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
Base
,
JSONField
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_session
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -82,7 +82,7 @@ class ToolValves(BaseModel):
class
ToolsTable
:
def
insert_new_tool
(
self
,
db
:
Session
,
user_id
:
str
,
form_data
:
ToolForm
,
specs
:
List
[
dict
]
self
,
user_id
:
str
,
form_data
:
ToolForm
,
specs
:
List
[
dict
]
)
->
Optional
[
ToolModel
]:
tool
=
ToolModel
(
**
{
...
...
@@ -95,7 +95,8 @@ class ToolsTable:
)
try
:
result
=
Tool
(
**
tool
.
dict
())
with
get_session
()
as
db
:
result
=
Tool
(
**
tool
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
...
...
@@ -107,19 +108,22 @@ class ToolsTable:
print
(
f
"Error creating tool:
{
e
}
"
)
return
None
def
get_tool_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
ToolModel
]:
def
get_tool_by_id
(
self
,
id
:
str
)
->
Optional
[
ToolModel
]:
try
:
with
get_session
()
as
db
:
tool
=
db
.
get
(
Tool
,
id
)
return
ToolModel
.
model_validate
(
tool
)
except
:
return
None
def
get_tools
(
self
,
db
:
Session
)
->
List
[
ToolModel
]:
def
get_tools
(
self
)
->
List
[
ToolModel
]:
with
get_session
()
as
db
:
return
[
ToolModel
.
model_validate
(
tool
)
for
tool
in
db
.
query
(
Tool
).
all
()]
def
get_tool_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
with
get_session
()
as
db
:
tool
=
db
.
get
(
Tool
,
id
)
return
tool
.
valves
if
tool
.
valves
else
{}
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
...
...
@@ -127,14 +131,12 @@ class ToolsTable:
def
update_tool_valves_by_id
(
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
ToolValves
]:
try
:
query
=
Tool
.
update
(
**
{
"valves"
:
valves
},
updated_at
=
int
(
time
.
time
()),
).
where
(
Tool
.
id
==
id
)
query
.
execute
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolValves
(
**
model_to_dict
(
tool
))
with
get_session
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
"valves"
:
valves
,
"updated_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
return
self
.
get_tool_by_id
(
id
)
except
:
return
None
...
...
@@ -172,8 +174,7 @@ class ToolsTable:
user_settings
[
"tools"
][
"valves"
][
id
]
=
valves
# Update the user settings in the database
query
=
Users
.
update_user_by_id
(
user_id
,
{
"settings"
:
user_settings
})
query
.
execute
()
Users
.
update_user_by_id
(
user_id
,
{
"settings"
:
user_settings
})
return
user_settings
[
"tools"
][
"valves"
][
id
]
except
Exception
as
e
:
...
...
@@ -182,15 +183,18 @@ class ToolsTable:
def
update_tool_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
ToolModel
]:
try
:
with
get_session
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
**
updated
,
"updated_at"
:
int
(
time
.
time
())}
)
return
self
.
get_tool_by_id
(
db
,
id
)
db
.
commit
()
return
self
.
get_tool_by_id
(
id
)
except
:
return
None
def
delete_tool_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
def
delete_tool_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_session
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
delete
()
return
True
except
:
...
...
backend/apps/webui/models/users.py
View file @
bee835cb
...
...
@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from
utils.misc
import
get_gravatar_url
from
apps.webui.internal.db
import
Base
,
JSONField
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_session
from
apps.webui.models.chats
import
Chats
####################
...
...
@@ -42,8 +42,6 @@ class UserSettings(BaseModel):
class
UserModel
(
BaseModel
):
model_config
=
ConfigDict
(
from_attributes
=
True
)
id
:
str
name
:
str
email
:
str
...
...
@@ -60,6 +58,8 @@ class UserModel(BaseModel):
oauth_sub
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -82,7 +82,6 @@ class UsersTable:
def
insert_new_user
(
self
,
db
:
Session
,
id
:
str
,
name
:
str
,
email
:
str
,
...
...
@@ -90,6 +89,7 @@ class UsersTable:
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
user
=
UserModel
(
**
{
"id"
:
id
,
...
...
@@ -112,21 +112,24 @@ class UsersTable:
else
:
return
None
def
get_user_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
Exception
as
e
:
return
None
def
get_user_by_api_key
(
self
,
db
:
Session
,
api_key
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
filter_by
(
api_key
=
api_key
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_user_by_email
(
self
,
db
:
Session
,
email
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
filter_by
(
email
=
email
).
first
()
return
UserModel
.
model_validate
(
user
)
...
...
@@ -134,13 +137,15 @@ class UsersTable:
return
None
def
get_user_by_oauth_sub
(
self
,
sub
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
user
=
User
.
get
(
User
.
oauth_sub
==
sub
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
oauth_sub
=
sub
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
get_users
(
self
,
db
:
Session
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
def
get_users
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
with
get_session
()
as
db
:
users
=
(
db
.
query
(
User
)
# .offset(skip).limit(limit)
...
...
@@ -148,10 +153,12 @@ class UsersTable:
)
return
[
UserModel
.
model_validate
(
user
)
for
user
in
users
]
def
get_num_users
(
self
,
db
:
Session
)
->
Optional
[
int
]:
def
get_num_users
(
self
)
->
Optional
[
int
]:
with
get_session
()
as
db
:
return
db
.
query
(
User
).
count
()
def
get_first_user
(
self
,
db
:
Session
)
->
UserModel
:
def
get_first_user
(
self
)
->
UserModel
:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
order_by
(
User
.
created_at
).
first
()
return
UserModel
.
model_validate
(
user
)
...
...
@@ -159,8 +166,9 @@ class UsersTable:
return
None
def
update_user_role_by_id
(
self
,
db
:
Session
,
id
:
str
,
role
:
str
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"role"
:
role
})
db
.
commit
()
...
...
@@ -171,8 +179,9 @@ class UsersTable:
return
None
def
update_user_profile_image_url_by_id
(
self
,
db
:
Session
,
id
:
str
,
profile_image_url
:
str
self
,
id
:
str
,
profile_image_url
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"profile_image_url"
:
profile_image_url
}
...
...
@@ -185,8 +194,9 @@ class UsersTable:
return
None
def
update_user_last_active_by_id
(
self
,
db
:
Session
,
id
:
str
self
,
id
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"last_active_at"
:
int
(
time
.
time
())})
...
...
@@ -196,8 +206,9 @@ class UsersTable:
return
None
def
update_user_oauth_sub_by_id
(
self
,
db
:
Session
,
id
:
str
,
oauth_sub
:
str
self
,
id
:
str
,
oauth_sub
:
str
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"oauth_sub"
:
oauth_sub
})
...
...
@@ -207,8 +218,9 @@ class UsersTable:
return
None
def
update_user_by_id
(
self
,
db
:
Session
,
id
:
str
,
updated
:
dict
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
with
get_session
()
as
db
:
try
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
updated
)
db
.
commit
()
...
...
@@ -219,10 +231,11 @@ class UsersTable:
except
Exception
as
e
:
return
None
def
delete_user_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
bool
:
def
delete_user_by_id
(
self
,
id
:
str
)
->
bool
:
with
get_session
()
as
db
:
try
:
# Delete User Chats
result
=
Chats
.
delete_chats_by_user_id
(
db
,
id
)
result
=
Chats
.
delete_chats_by_user_id
(
id
)
if
result
:
# Delete User
...
...
@@ -235,7 +248,8 @@ class UsersTable:
except
:
return
False
def
update_user_api_key_by_id
(
self
,
db
:
Session
,
id
:
str
,
api_key
:
str
)
->
str
:
def
update_user_api_key_by_id
(
self
,
id
:
str
,
api_key
:
str
)
->
str
:
with
get_session
()
as
db
:
try
:
result
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"api_key"
:
api_key
})
db
.
commit
()
...
...
@@ -243,7 +257,8 @@ class UsersTable:
except
:
return
False
def
get_user_api_key_by_id
(
self
,
db
:
Session
,
id
:
str
)
->
Optional
[
str
]:
def
get_user_api_key_by_id
(
self
,
id
:
str
)
->
Optional
[
str
]:
with
get_session
()
as
db
:
try
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
user
.
api_key
...
...
backend/apps/webui/routers/auths.py
View file @
bee835cb
...
...
@@ -10,7 +10,6 @@ import re
import
uuid
import
csv
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.auths
import
(
SigninForm
,
SignupForm
,
...
...
@@ -80,12 +79,10 @@ async def get_session_user(
@
router
.
post
(
"/update/profile"
,
response_model
=
UserResponse
)
async
def
update_profile
(
form_data
:
UpdateProfileForm
,
session_user
=
Depends
(
get_current_user
),
db
=
Depends
(
get_db
),
session_user
=
Depends
(
get_current_user
)
):
if
session_user
:
user
=
Users
.
update_user_by_id
(
db
,
session_user
.
id
,
{
"profile_image_url"
:
form_data
.
profile_image_url
,
"name"
:
form_data
.
name
},
)
...
...
@@ -105,17 +102,16 @@ async def update_profile(
@
router
.
post
(
"/update/password"
,
response_model
=
bool
)
async
def
update_password
(
form_data
:
UpdatePasswordForm
,
session_user
=
Depends
(
get_current_user
),
db
=
Depends
(
get_db
),
session_user
=
Depends
(
get_current_user
)
):
if
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
ACTION_PROHIBITED
)
if
session_user
:
user
=
Auths
.
authenticate_user
(
db
,
session_user
.
email
,
form_data
.
password
)
user
=
Auths
.
authenticate_user
(
session_user
.
email
,
form_data
.
password
)
if
user
:
hashed
=
get_password_hash
(
form_data
.
new_password
)
return
Auths
.
update_user_password_by_id
(
db
,
user
.
id
,
hashed
)
return
Auths
.
update_user_password_by_id
(
user
.
id
,
hashed
)
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_PASSWORD
)
else
:
...
...
@@ -128,7 +124,7 @@ async def update_password(
@
router
.
post
(
"/signin"
,
response_model
=
SigninResponse
)
async
def
signin
(
request
:
Request
,
response
:
Response
,
form_data
:
SigninForm
,
db
=
Depends
(
get_db
)
):
async
def
signin
(
request
:
Request
,
response
:
Response
,
form_data
:
SigninForm
):
if
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
:
if
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
not
in
request
.
headers
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_TRUSTED_HEADER
)
...
...
@@ -139,34 +135,32 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db
trusted_name
=
request
.
headers
.
get
(
WEBUI_AUTH_TRUSTED_NAME_HEADER
,
trusted_email
)
if
not
Users
.
get_user_by_email
(
db
,
trusted_email
.
lower
()):
if
not
Users
.
get_user_by_email
(
trusted_email
.
lower
()):
await
signup
(
request
,
SignupForm
(
email
=
trusted_email
,
password
=
str
(
uuid
.
uuid4
()),
name
=
trusted_name
),
db
,
)
user
=
Auths
.
authenticate_user_by_trusted_header
(
db
,
trusted_email
)
user
=
Auths
.
authenticate_user_by_trusted_header
(
trusted_email
)
elif
WEBUI_AUTH
==
False
:
admin_email
=
"admin@localhost"
admin_password
=
"admin"
if
Users
.
get_user_by_email
(
db
,
admin_email
.
lower
()):
user
=
Auths
.
authenticate_user
(
db
,
admin_email
.
lower
(),
admin_password
)
if
Users
.
get_user_by_email
(
admin_email
.
lower
()):
user
=
Auths
.
authenticate_user
(
admin_email
.
lower
(),
admin_password
)
else
:
if
Users
.
get_num_users
(
db
)
!=
0
:
if
Users
.
get_num_users
()
!=
0
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EXISTING_USERS
)
await
signup
(
request
,
SignupForm
(
email
=
admin_email
,
password
=
admin_password
,
name
=
"User"
),
db
,
)
user
=
Auths
.
authenticate_user
(
db
,
admin_email
.
lower
(),
admin_password
)
user
=
Auths
.
authenticate_user
(
admin_email
.
lower
(),
admin_password
)
else
:
user
=
Auths
.
authenticate_user
(
db
,
form_data
.
email
.
lower
(),
form_data
.
password
)
user
=
Auths
.
authenticate_user
(
form_data
.
email
.
lower
(),
form_data
.
password
)
if
user
:
token
=
create_token
(
...
...
@@ -200,7 +194,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db
@
router
.
post
(
"/signup"
,
response_model
=
SigninResponse
)
async
def
signup
(
request
:
Request
,
response
:
Response
,
form_data
:
SignupForm
,
db
=
Depends
(
get_db
)
):
async
def
signup
(
request
:
Request
,
response
:
Response
,
form_data
:
SignupForm
):
if
not
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
and
WEBUI_AUTH
:
raise
HTTPException
(
status
.
HTTP_403_FORBIDDEN
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
...
...
@@ -211,18 +205,17 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
INVALID_EMAIL_FORMAT
)
if
Users
.
get_user_by_email
(
db
,
form_data
.
email
.
lower
()):
if
Users
.
get_user_by_email
(
form_data
.
email
.
lower
()):
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EMAIL_TAKEN
)
try
:
role
=
(
"admin"
if
Users
.
get_num_users
(
db
)
==
0
if
Users
.
get_num_users
()
==
0
else
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
)
hashed
=
get_password_hash
(
form_data
.
password
)
user
=
Auths
.
insert_new_auth
(
db
,
form_data
.
email
.
lower
(),
hashed
,
form_data
.
name
,
...
...
@@ -277,7 +270,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db
@
router
.
post
(
"/add"
,
response_model
=
SigninResponse
)
async
def
add_user
(
form_data
:
AddUserForm
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
form_data
:
AddUserForm
,
user
=
Depends
(
get_admin_user
)
):
if
not
validate_email_format
(
form_data
.
email
.
lower
()):
...
...
@@ -285,7 +278,7 @@ async def add_user(
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
INVALID_EMAIL_FORMAT
)
if
Users
.
get_user_by_email
(
db
,
form_data
.
email
.
lower
()):
if
Users
.
get_user_by_email
(
form_data
.
email
.
lower
()):
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EMAIL_TAKEN
)
try
:
...
...
@@ -293,7 +286,6 @@ async def add_user(
print
(
form_data
)
hashed
=
get_password_hash
(
form_data
.
password
)
user
=
Auths
.
insert_new_auth
(
db
,
form_data
.
email
.
lower
(),
hashed
,
form_data
.
name
,
...
...
@@ -325,7 +317,7 @@ async def add_user(
@
router
.
get
(
"/admin/details"
)
async
def
get_admin_details
(
request
:
Request
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
request
:
Request
,
user
=
Depends
(
get_current_user
)
):
if
request
.
app
.
state
.
config
.
SHOW_ADMIN_DETAILS
:
admin_email
=
request
.
app
.
state
.
config
.
ADMIN_EMAIL
...
...
@@ -334,11 +326,11 @@ async def get_admin_details(
print
(
admin_email
,
admin_name
)
if
admin_email
:
admin
=
Users
.
get_user_by_email
(
db
,
admin_email
)
admin
=
Users
.
get_user_by_email
(
admin_email
)
if
admin
:
admin_name
=
admin
.
name
else
:
admin
=
Users
.
get_first_user
(
db
)
admin
=
Users
.
get_first_user
()
if
admin
:
admin_email
=
admin
.
email
admin_name
=
admin
.
name
...
...
@@ -411,9 +403,9 @@ async def update_admin_config(
# create api key
@
router
.
post
(
"/api_key"
,
response_model
=
ApiKey
)
async
def
create_api_key_
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
async
def
create_api_key_
(
user
=
Depends
(
get_current_user
)):
api_key
=
create_api_key
()
success
=
Users
.
update_user_api_key_by_id
(
db
,
user
.
id
,
api_key
)
success
=
Users
.
update_user_api_key_by_id
(
user
.
id
,
api_key
)
if
success
:
return
{
"api_key"
:
api_key
,
...
...
@@ -424,15 +416,15 @@ async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)):
# delete api key
@
router
.
delete
(
"/api_key"
,
response_model
=
bool
)
async
def
delete_api_key
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
success
=
Users
.
update_user_api_key_by_id
(
db
,
user
.
id
,
None
)
async
def
delete_api_key
(
user
=
Depends
(
get_current_user
)):
success
=
Users
.
update_user_api_key_by_id
(
user
.
id
,
None
)
return
success
# get api key
@
router
.
get
(
"/api_key"
,
response_model
=
ApiKey
)
async
def
get_api_key
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
api_key
=
Users
.
get_user_api_key_by_id
(
db
,
user
.
id
)
async
def
get_api_key
(
user
=
Depends
(
get_current_user
)):
api_key
=
Users
.
get_user_api_key_by_id
(
user
.
id
)
if
api_key
:
return
{
"api_key"
:
api_key
,
...
...
backend/apps/webui/routers/chats.py
View file @
bee835cb
...
...
@@ -2,7 +2,6 @@ from fastapi import Depends, Request, HTTPException, status
from
datetime
import
datetime
,
timedelta
from
typing
import
List
,
Union
,
Optional
from
apps.webui.internal.db
import
get_db
from
utils.utils
import
get_current_user
,
get_admin_user
from
fastapi
import
APIRouter
from
pydantic
import
BaseModel
...
...
@@ -45,9 +44,9 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
ChatTitleIdResponse
])
@
router
.
get
(
"/list"
,
response_model
=
List
[
ChatTitleIdResponse
])
async
def
get_session_user_chat_list
(
user
=
Depends
(
get_current_user
),
skip
:
int
=
0
,
limit
:
int
=
50
,
db
=
Depends
(
get_db
)
user
=
Depends
(
get_current_user
),
skip
:
int
=
0
,
limit
:
int
=
50
):
return
Chats
.
get_chat_list_by_user_id
(
db
,
user
.
id
,
skip
,
limit
)
return
Chats
.
get_chat_list_by_user_id
(
user
.
id
,
skip
,
limit
)
############################
...
...
@@ -57,7 +56,7 @@ async def get_session_user_chat_list(
@
router
.
delete
(
"/"
,
response_model
=
bool
)
async
def
delete_all_user_chats
(
request
:
Request
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
request
:
Request
,
user
=
Depends
(
get_current_user
)
):
if
(
...
...
@@ -69,7 +68,7 @@ async def delete_all_user_chats(
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
,
)
result
=
Chats
.
delete_chats_by_user_id
(
db
,
user
.
id
)
result
=
Chats
.
delete_chats_by_user_id
(
user
.
id
)
return
result
...
...
@@ -84,10 +83,9 @@ async def get_user_chat_list_by_user_id(
user
=
Depends
(
get_admin_user
),
skip
:
int
=
0
,
limit
:
int
=
50
,
db
=
Depends
(
get_db
),
):
return
Chats
.
get_chat_list_by_user_id
(
db
,
user_id
,
include_archived
=
True
,
skip
=
skip
,
limit
=
limit
user_id
,
include_archived
=
True
,
skip
=
skip
,
limit
=
limit
)
...
...
@@ -98,10 +96,10 @@ async def get_user_chat_list_by_user_id(
@
router
.
post
(
"/new"
,
response_model
=
Optional
[
ChatResponse
])
async
def
create_new_chat
(
form_data
:
ChatForm
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
form_data
:
ChatForm
,
user
=
Depends
(
get_current_user
)
):
try
:
chat
=
Chats
.
insert_new_chat
(
db
,
user
.
id
,
form_data
)
chat
=
Chats
.
insert_new_chat
(
user
.
id
,
form_data
)
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
except
Exception
as
e
:
log
.
exception
(
e
)
...
...
@@ -116,10 +114,10 @@ async def create_new_chat(
@
router
.
get
(
"/all"
,
response_model
=
List
[
ChatResponse
])
async
def
get_user_chats
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
async
def
get_user_chats
(
user
=
Depends
(
get_current_user
)):
return
[
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
for
chat
in
Chats
.
get_chats_by_user_id
(
db
,
user
.
id
)
for
chat
in
Chats
.
get_chats_by_user_id
(
user
.
id
)
]
...
...
@@ -129,10 +127,10 @@ async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)):
@
router
.
get
(
"/all/archived"
,
response_model
=
List
[
ChatResponse
])
async
def
get_user_archived_chats
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
async
def
get_user_archived_chats
(
user
=
Depends
(
get_current_user
)):
return
[
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
for
chat
in
Chats
.
get_archived_chats_by_user_id
(
db
,
user
.
id
)
for
chat
in
Chats
.
get_archived_chats_by_user_id
(
user
.
id
)
]
...
...
@@ -142,7 +140,7 @@ async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get
@
router
.
get
(
"/all/db"
,
response_model
=
List
[
ChatResponse
])
async
def
get_all_user_chats_in_db
(
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
):
async
def
get_all_user_chats_in_db
(
user
=
Depends
(
get_admin_user
)):
if
not
ENABLE_ADMIN_EXPORT
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
...
@@ -150,7 +148,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_
)
return
[
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
for
chat
in
Chats
.
get_chats
(
db
)
for
chat
in
Chats
.
get_chats
()
]
...
...
@@ -161,9 +159,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_
@
router
.
get
(
"/archived"
,
response_model
=
List
[
ChatTitleIdResponse
])
async
def
get_archived_session_user_chat_list
(
user
=
Depends
(
get_current_user
),
skip
:
int
=
0
,
limit
:
int
=
50
,
db
=
Depends
(
get_db
)
user
=
Depends
(
get_current_user
),
skip
:
int
=
0
,
limit
:
int
=
50
):
return
Chats
.
get_archived_chat_list_by_user_id
(
db
,
user
.
id
,
skip
,
limit
)
return
Chats
.
get_archived_chat_list_by_user_id
(
user
.
id
,
skip
,
limit
)
############################
...
...
@@ -172,8 +170,8 @@ async def get_archived_session_user_chat_list(
@
router
.
post
(
"/archive/all"
,
response_model
=
bool
)
async
def
archive_all_chats
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
return
Chats
.
archive_all_chats_by_user_id
(
db
,
user
.
id
)
async
def
archive_all_chats
(
user
=
Depends
(
get_current_user
)):
return
Chats
.
archive_all_chats_by_user_id
(
user
.
id
)
############################
...
...
@@ -183,7 +181,7 @@ async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)):
@
router
.
get
(
"/share/{share_id}"
,
response_model
=
Optional
[
ChatResponse
])
async
def
get_shared_chat_by_id
(
share_id
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
share_id
:
str
,
user
=
Depends
(
get_current_user
)
):
if
user
.
role
==
"pending"
:
raise
HTTPException
(
...
...
@@ -191,9 +189,9 @@ async def get_shared_chat_by_id(
)
if
user
.
role
==
"user"
:
chat
=
Chats
.
get_chat_by_share_id
(
db
,
share_id
)
chat
=
Chats
.
get_chat_by_share_id
(
share_id
)
elif
user
.
role
==
"admin"
:
chat
=
Chats
.
get_chat_by_id
(
db
,
share_id
)
chat
=
Chats
.
get_chat_by_id
(
share_id
)
if
chat
:
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
...
...
@@ -216,23 +214,23 @@ class TagNameForm(BaseModel):
@
router
.
post
(
"/tags"
,
response_model
=
List
[
ChatTitleIdResponse
])
async
def
get_user_chat_list_by_tag_name
(
form_data
:
TagNameForm
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
form_data
:
TagNameForm
,
user
=
Depends
(
get_current_user
)
):
print
(
form_data
)
chat_ids
=
[
chat_id_tag
.
chat_id
for
chat_id_tag
in
Tags
.
get_chat_ids_by_tag_name_and_user_id
(
db
,
form_data
.
name
,
user
.
id
form_data
.
name
,
user
.
id
)
]
chats
=
Chats
.
get_chat_list_by_chat_ids
(
db
,
chat_ids
,
form_data
.
skip
,
form_data
.
limit
chat_ids
,
form_data
.
skip
,
form_data
.
limit
)
if
len
(
chats
)
==
0
:
Tags
.
delete_tag_by_tag_name_and_user_id
(
db
,
form_data
.
name
,
user
.
id
)
Tags
.
delete_tag_by_tag_name_and_user_id
(
form_data
.
name
,
user
.
id
)
return
chats
...
...
@@ -243,9 +241,9 @@ async def get_user_chat_list_by_tag_name(
@
router
.
get
(
"/tags/all"
,
response_model
=
List
[
TagModel
])
async
def
get_all_tags
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
async
def
get_all_tags
(
user
=
Depends
(
get_current_user
)):
try
:
tags
=
Tags
.
get_tags_by_user_id
(
db
,
user
.
id
)
tags
=
Tags
.
get_tags_by_user_id
(
user
.
id
)
return
tags
except
Exception
as
e
:
log
.
exception
(
e
)
...
...
@@ -260,8 +258,8 @@ async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)):
@
router
.
get
(
"/{id}"
,
response_model
=
Optional
[
ChatResponse
])
async
def
get_chat_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
db
,
id
,
user
.
id
)
async
def
get_chat_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
if
chat
:
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
...
...
@@ -278,13 +276,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get
@
router
.
post
(
"/{id}"
,
response_model
=
Optional
[
ChatResponse
])
async
def
update_chat_by_id
(
id
:
str
,
form_data
:
ChatForm
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
id
:
str
,
form_data
:
ChatForm
,
user
=
Depends
(
get_current_user
)
):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
db
,
id
,
user
.
id
)
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
if
chat
:
updated_chat
=
{
**
json
.
loads
(
chat
.
chat
),
**
form_data
.
chat
}
chat
=
Chats
.
update_chat_by_id
(
db
,
id
,
updated_chat
)
chat
=
Chats
.
update_chat_by_id
(
id
,
updated_chat
)
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
else
:
raise
HTTPException
(
...
...
@@ -300,11 +298,11 @@ async def update_chat_by_id(
@
router
.
delete
(
"/{id}"
,
response_model
=
bool
)
async
def
delete_chat_by_id
(
request
:
Request
,
id
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
request
:
Request
,
id
:
str
,
user
=
Depends
(
get_current_user
)
):
if
user
.
role
==
"admin"
:
result
=
Chats
.
delete_chat_by_id
(
db
,
id
)
result
=
Chats
.
delete_chat_by_id
(
id
)
return
result
else
:
if
not
request
.
app
.
state
.
config
.
USER_PERMISSIONS
[
"chat"
][
"deletion"
]:
...
...
@@ -313,7 +311,7 @@ async def delete_chat_by_id(
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
,
)
result
=
Chats
.
delete_chat_by_id_and_user_id
(
db
,
id
,
user
.
id
)
result
=
Chats
.
delete_chat_by_id_and_user_id
(
id
,
user
.
id
)
return
result
...
...
@@ -323,8 +321,8 @@ async def delete_chat_by_id(
@
router
.
get
(
"/{id}/clone"
,
response_model
=
Optional
[
ChatResponse
])
async
def
clone_chat_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
db
,
id
,
user
.
id
)
async
def
clone_chat_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
if
chat
:
chat_body
=
json
.
loads
(
chat
.
chat
)
...
...
@@ -335,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
"title"
:
f
"Clone of
{
chat
.
title
}
"
,
}
chat
=
Chats
.
insert_new_chat
(
db
,
user
.
id
,
ChatForm
(
**
{
"chat"
:
updated_chat
}))
chat
=
Chats
.
insert_new_chat
(
user
.
id
,
ChatForm
(
**
{
"chat"
:
updated_chat
}))
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
else
:
raise
HTTPException
(
...
...
@@ -350,11 +348,11 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
@
router
.
get
(
"/{id}/archive"
,
response_model
=
Optional
[
ChatResponse
])
async
def
archive_chat_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
id
:
str
,
user
=
Depends
(
get_current_user
)
):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
db
,
id
,
user
.
id
)
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
if
chat
:
chat
=
Chats
.
toggle_chat_archive_by_id
(
db
,
id
)
chat
=
Chats
.
toggle_chat_archive_by_id
(
id
)
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
else
:
raise
HTTPException
(
...
...
@@ -368,16 +366,16 @@ async def archive_chat_by_id(
@
router
.
post
(
"/{id}/share"
,
response_model
=
Optional
[
ChatResponse
])
async
def
share_chat_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
db
,
id
,
user
.
id
)
async
def
share_chat_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
if
chat
:
if
chat
.
share_id
:
shared_chat
=
Chats
.
update_shared_chat_by_chat_id
(
db
,
chat
.
id
)
shared_chat
=
Chats
.
update_shared_chat_by_chat_id
(
chat
.
id
)
return
ChatResponse
(
**
{
**
shared_chat
.
model_dump
(),
"chat"
:
json
.
loads
(
shared_chat
.
chat
)}
)
shared_chat
=
Chats
.
insert_shared_chat_by_chat_id
(
db
,
chat
.
id
)
shared_chat
=
Chats
.
insert_shared_chat_by_chat_id
(
chat
.
id
)
if
not
shared_chat
:
raise
HTTPException
(
status_code
=
status
.
HTTP_500_INTERNAL_SERVER_ERROR
,
...
...
@@ -401,15 +399,15 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
@
router
.
delete
(
"/{id}/share"
,
response_model
=
Optional
[
bool
])
async
def
delete_shared_chat_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
id
:
str
,
user
=
Depends
(
get_current_user
)
):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
db
,
id
,
user
.
id
)
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
if
chat
:
if
not
chat
.
share_id
:
return
False
result
=
Chats
.
delete_shared_chat_by_chat_id
(
db
,
id
)
update_result
=
Chats
.
update_chat_share_id_by_id
(
db
,
id
,
None
)
result
=
Chats
.
delete_shared_chat_by_chat_id
(
id
)
update_result
=
Chats
.
update_chat_share_id_by_id
(
id
,
None
)
return
result
and
update_result
!=
None
else
:
...
...
@@ -426,9 +424,9 @@ async def delete_shared_chat_by_id(
@
router
.
get
(
"/{id}/tags"
,
response_model
=
List
[
TagModel
])
async
def
get_chat_tags_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
id
:
str
,
user
=
Depends
(
get_current_user
)
):
tags
=
Tags
.
get_tags_by_chat_id_and_user_id
(
db
,
id
,
user
.
id
)
tags
=
Tags
.
get_tags_by_chat_id_and_user_id
(
id
,
user
.
id
)
if
tags
!=
None
:
return
tags
...
...
@@ -447,13 +445,12 @@ async def get_chat_tags_by_id(
async
def
add_chat_tag_by_id
(
id
:
str
,
form_data
:
ChatIdTagForm
,
user
=
Depends
(
get_current_user
),
db
=
Depends
(
get_db
),
user
=
Depends
(
get_current_user
)
):
tags
=
Tags
.
get_tags_by_chat_id_and_user_id
(
db
,
id
,
user
.
id
)
tags
=
Tags
.
get_tags_by_chat_id_and_user_id
(
id
,
user
.
id
)
if
form_data
.
tag_name
not
in
tags
:
tag
=
Tags
.
add_tag_to_chat
(
db
,
user
.
id
,
form_data
)
tag
=
Tags
.
add_tag_to_chat
(
user
.
id
,
form_data
)
if
tag
:
return
tag
...
...
@@ -478,10 +475,9 @@ async def delete_chat_tag_by_id(
id
:
str
,
form_data
:
ChatIdTagForm
,
user
=
Depends
(
get_current_user
),
db
=
Depends
(
get_db
),
):
result
=
Tags
.
delete_tag_by_tag_name_and_chat_id_and_user_id
(
db
,
form_data
.
tag_name
,
id
,
user
.
id
form_data
.
tag_name
,
id
,
user
.
id
)
if
result
:
...
...
@@ -499,9 +495,9 @@ async def delete_chat_tag_by_id(
@
router
.
delete
(
"/{id}/tags/all"
,
response_model
=
Optional
[
bool
])
async
def
delete_all_chat_tags_by_id
(
id
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
id
:
str
,
user
=
Depends
(
get_current_user
)
):
result
=
Tags
.
delete_tags_by_chat_id_and_user_id
(
db
,
id
,
user
.
id
)
result
=
Tags
.
delete_tags_by_chat_id_and_user_id
(
id
,
user
.
id
)
if
result
:
return
result
...
...
backend/apps/webui/routers/documents.py
View file @
bee835cb
...
...
@@ -6,7 +6,6 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.documents
import
(
Documents
,
DocumentForm
,
...
...
@@ -26,7 +25,7 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
DocumentResponse
])
async
def
get_documents
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
async
def
get_documents
(
user
=
Depends
(
get_current_user
)):
docs
=
[
DocumentResponse
(
**
{
...
...
@@ -34,7 +33,7 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
"content"
:
json
.
loads
(
doc
.
content
if
doc
.
content
else
"{}"
),
}
)
for
doc
in
Documents
.
get_docs
(
db
)
for
doc
in
Documents
.
get_docs
()
]
return
docs
...
...
@@ -46,11 +45,11 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
@
router
.
post
(
"/create"
,
response_model
=
Optional
[
DocumentResponse
])
async
def
create_new_doc
(
form_data
:
DocumentForm
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
form_data
:
DocumentForm
,
user
=
Depends
(
get_admin_user
)
):
doc
=
Documents
.
get_doc_by_name
(
db
,
form_data
.
name
)
doc
=
Documents
.
get_doc_by_name
(
form_data
.
name
)
if
doc
==
None
:
doc
=
Documents
.
insert_new_doc
(
db
,
user
.
id
,
form_data
)
doc
=
Documents
.
insert_new_doc
(
user
.
id
,
form_data
)
if
doc
:
return
DocumentResponse
(
...
...
@@ -78,9 +77,9 @@ async def create_new_doc(
@
router
.
get
(
"/doc"
,
response_model
=
Optional
[
DocumentResponse
])
async
def
get_doc_by_name
(
name
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
name
:
str
,
user
=
Depends
(
get_current_user
)
):
doc
=
Documents
.
get_doc_by_name
(
db
,
name
)
doc
=
Documents
.
get_doc_by_name
(
name
)
if
doc
:
return
DocumentResponse
(
...
...
@@ -112,10 +111,10 @@ class TagDocumentForm(BaseModel):
@
router
.
post
(
"/doc/tags"
,
response_model
=
Optional
[
DocumentResponse
])
async
def
tag_doc_by_name
(
form_data
:
TagDocumentForm
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
form_data
:
TagDocumentForm
,
user
=
Depends
(
get_current_user
)
):
doc
=
Documents
.
update_doc_content_by_name
(
db
,
form_data
.
name
,
{
"tags"
:
form_data
.
tags
}
form_data
.
name
,
{
"tags"
:
form_data
.
tags
}
)
if
doc
:
...
...
@@ -142,9 +141,8 @@ async def update_doc_by_name(
name
:
str
,
form_data
:
DocumentUpdateForm
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
doc
=
Documents
.
update_doc_by_name
(
db
,
name
,
form_data
)
doc
=
Documents
.
update_doc_by_name
(
name
,
form_data
)
if
doc
:
return
DocumentResponse
(
**
{
...
...
@@ -166,7 +164,7 @@ async def update_doc_by_name(
@
router
.
delete
(
"/doc/delete"
,
response_model
=
bool
)
async
def
delete_doc_by_name
(
name
:
str
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
name
:
str
,
user
=
Depends
(
get_admin_user
)
):
result
=
Documents
.
delete_doc_by_name
(
db
,
name
)
result
=
Documents
.
delete_doc_by_name
(
name
)
return
result
backend/apps/webui/routers/files.py
View file @
bee835cb
...
...
@@ -20,7 +20,6 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from
pydantic
import
BaseModel
import
json
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.files
import
(
Files
,
FileForm
,
...
...
@@ -53,8 +52,7 @@ router = APIRouter()
@
router
.
post
(
"/"
)
def
upload_file
(
file
:
UploadFile
=
File
(...),
user
=
Depends
(
get_verified_user
),
db
=
Depends
(
get_db
)
user
=
Depends
(
get_verified_user
)
):
log
.
info
(
f
"file.content_type:
{
file
.
content_type
}
"
)
try
:
...
...
@@ -72,7 +70,6 @@ def upload_file(
f
.
close
()
file
=
Files
.
insert_new_file
(
db
,
user
.
id
,
FileForm
(
**
{
...
...
@@ -109,8 +106,8 @@ def upload_file(
@
router
.
get
(
"/"
,
response_model
=
List
[
FileModel
])
async
def
list_files
(
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
files
=
Files
.
get_files
(
db
)
async
def
list_files
(
user
=
Depends
(
get_verified_user
)):
files
=
Files
.
get_files
()
return
files
...
...
@@ -120,8 +117,8 @@ async def list_files(user=Depends(get_verified_user), db=Depends(get_db)):
@
router
.
delete
(
"/all"
)
async
def
delete_all_files
(
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
):
result
=
Files
.
delete_all_files
(
db
)
async
def
delete_all_files
(
user
=
Depends
(
get_admin_user
)):
result
=
Files
.
delete_all_files
()
if
result
:
folder
=
f
"
{
UPLOAD_DIR
}
"
...
...
@@ -157,8 +154,8 @@ async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)):
@
router
.
get
(
"/{id}"
,
response_model
=
Optional
[
FileModel
])
async
def
get_file_by_id
(
id
:
str
,
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
file
=
Files
.
get_file_by_id
(
db
,
id
)
async
def
get_file_by_id
(
id
:
str
,
user
=
Depends
(
get_verified_user
)):
file
=
Files
.
get_file_by_id
(
id
)
if
file
:
return
file
...
...
@@ -175,8 +172,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(ge
@
router
.
get
(
"/{id}/content"
,
response_model
=
Optional
[
FileModel
])
async
def
get_file_content_by_id
(
id
:
str
,
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
file
=
Files
.
get_file_by_id
(
db
,
id
)
async
def
get_file_content_by_id
(
id
:
str
,
user
=
Depends
(
get_verified_user
)):
file
=
Files
.
get_file_by_id
(
id
)
if
file
:
file_path
=
Path
(
file
.
meta
[
"path"
])
...
...
@@ -226,11 +223,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@
router
.
delete
(
"/{id}"
)
async
def
delete_file_by_id
(
id
:
str
,
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
file
=
Files
.
get_file_by_id
(
db
,
id
)
async
def
delete_file_by_id
(
id
:
str
,
user
=
Depends
(
get_verified_user
)):
file
=
Files
.
get_file_by_id
(
id
)
if
file
:
result
=
Files
.
delete_file_by_id
(
db
,
id
)
result
=
Files
.
delete_file_by_id
(
id
)
if
result
:
return
{
"message"
:
"File deleted successfully"
}
else
:
...
...
backend/apps/webui/routers/functions.py
View file @
bee835cb
...
...
@@ -6,7 +6,6 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.functions
import
(
Functions
,
FunctionForm
,
...
...
@@ -32,8 +31,8 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
FunctionResponse
])
async
def
get_functions
(
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
return
Functions
.
get_functions
(
db
)
async
def
get_functions
(
user
=
Depends
(
get_verified_user
)):
return
Functions
.
get_functions
()
############################
...
...
@@ -42,8 +41,8 @@ async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)):
@
router
.
get
(
"/export"
,
response_model
=
List
[
FunctionModel
])
async
def
get_functions
(
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
):
return
Functions
.
get_functions
(
db
)
async
def
get_functions
(
user
=
Depends
(
get_admin_user
)):
return
Functions
.
get_functions
()
############################
...
...
@@ -53,7 +52,7 @@ async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)):
@
router
.
post
(
"/create"
,
response_model
=
Optional
[
FunctionResponse
])
async
def
create_new_function
(
request
:
Request
,
form_data
:
FunctionForm
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
request
:
Request
,
form_data
:
FunctionForm
,
user
=
Depends
(
get_admin_user
)
):
if
not
form_data
.
id
.
isidentifier
():
raise
HTTPException
(
...
...
@@ -63,7 +62,7 @@ async def create_new_function(
form_data
.
id
=
form_data
.
id
.
lower
()
function
=
Functions
.
get_function_by_id
(
db
,
form_data
.
id
)
function
=
Functions
.
get_function_by_id
(
form_data
.
id
)
if
function
==
None
:
function_path
=
os
.
path
.
join
(
FUNCTIONS_DIR
,
f
"
{
form_data
.
id
}
.py"
)
try
:
...
...
@@ -78,7 +77,7 @@ async def create_new_function(
FUNCTIONS
=
request
.
app
.
state
.
FUNCTIONS
FUNCTIONS
[
form_data
.
id
]
=
function_module
function
=
Functions
.
insert_new_function
(
db
,
user
.
id
,
function_type
,
form_data
)
function
=
Functions
.
insert_new_function
(
user
.
id
,
function_type
,
form_data
)
function_cache_dir
=
Path
(
CACHE_DIR
)
/
"functions"
/
form_data
.
id
function_cache_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
...
...
@@ -109,8 +108,8 @@ async def create_new_function(
@
router
.
get
(
"/id/{id}"
,
response_model
=
Optional
[
FunctionModel
])
async
def
get_function_by_id
(
id
:
str
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
):
function
=
Functions
.
get_function_by_id
(
db
,
id
)
async
def
get_function_by_id
(
id
:
str
,
user
=
Depends
(
get_admin_user
)):
function
=
Functions
.
get_function_by_id
(
id
)
if
function
:
return
function
...
...
@@ -155,7 +154,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
@
router
.
post
(
"/id/{id}/update"
,
response_model
=
Optional
[
FunctionModel
])
async
def
update_function_by_id
(
request
:
Request
,
id
:
str
,
form_data
:
FunctionForm
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
request
:
Request
,
id
:
str
,
form_data
:
FunctionForm
,
user
=
Depends
(
get_admin_user
)
):
function_path
=
os
.
path
.
join
(
FUNCTIONS_DIR
,
f
"
{
id
}
.py"
)
...
...
@@ -172,7 +171,7 @@ async def update_function_by_id(
updated
=
{
**
form_data
.
model_dump
(
exclude
=
{
"id"
}),
"type"
:
function_type
}
print
(
updated
)
function
=
Functions
.
update_function_by_id
(
db
,
id
,
updated
)
function
=
Functions
.
update_function_by_id
(
id
,
updated
)
if
function
:
return
function
...
...
@@ -196,9 +195,9 @@ async def update_function_by_id(
@
router
.
delete
(
"/id/{id}/delete"
,
response_model
=
bool
)
async
def
delete_function_by_id
(
request
:
Request
,
id
:
str
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
request
:
Request
,
id
:
str
,
user
=
Depends
(
get_admin_user
)
):
result
=
Functions
.
delete_function_by_id
(
db
,
id
)
result
=
Functions
.
delete_function_by_id
(
id
)
if
result
:
FUNCTIONS
=
request
.
app
.
state
.
FUNCTIONS
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment