Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
bee835cb
Commit
bee835cb
authored
Jun 21, 2024
by
Jonathan Rohde
Browse files
feat(sqlalchemy): remove session reference from router
parent
df09d083
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
278 additions
and
329 deletions
+278
-329
backend/apps/webui/routers/memories.py
backend/apps/webui/routers/memories.py
+10
-13
backend/apps/webui/routers/models.py
backend/apps/webui/routers/models.py
+10
-13
backend/apps/webui/routers/prompts.py
backend/apps/webui/routers/prompts.py
+10
-12
backend/apps/webui/routers/tools.py
backend/apps/webui/routers/tools.py
+10
-13
backend/apps/webui/routers/users.py
backend/apps/webui/routers/users.py
+23
-26
backend/main.py
backend/main.py
+13
-18
backend/migrations/versions/22b5ab2667b8_init.py
backend/migrations/versions/22b5ab2667b8_init.py
+0
-188
backend/migrations/versions/ba76b0bae648_init.py
backend/migrations/versions/ba76b0bae648_init.py
+161
-0
backend/test/apps/webui/routers/test_auths.py
backend/test/apps/webui/routers/test_auths.py
+6
-13
backend/test/apps/webui/routers/test_chats.py
backend/test/apps/webui/routers/test_chats.py
+17
-21
backend/test/apps/webui/routers/test_documents.py
backend/test/apps/webui/routers/test_documents.py
+5
-5
backend/test/apps/webui/routers/test_prompts.py
backend/test/apps/webui/routers/test_prompts.py
+10
-0
backend/test/apps/webui/routers/test_users.py
backend/test/apps/webui/routers/test_users.py
+0
-2
backend/utils/utils.py
backend/utils/utils.py
+3
-5
No files found.
backend/apps/webui/routers/memories.py
View file @
bee835cb
...
...
@@ -7,7 +7,6 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
import
logging
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.memories
import
Memories
,
MemoryModel
from
utils.utils
import
get_verified_user
...
...
@@ -32,8 +31,8 @@ async def get_embeddings(request: Request):
@
router
.
get
(
"/"
,
response_model
=
List
[
MemoryModel
])
async
def
get_memories
(
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
return
Memories
.
get_memories_by_user_id
(
db
,
user
.
id
)
async
def
get_memories
(
user
=
Depends
(
get_verified_user
)):
return
Memories
.
get_memories_by_user_id
(
user
.
id
)
############################
...
...
@@ -54,9 +53,8 @@ async def add_memory(
request
:
Request
,
form_data
:
AddMemoryForm
,
user
=
Depends
(
get_verified_user
),
db
=
Depends
(
get_db
),
):
memory
=
Memories
.
insert_new_memory
(
db
,
user
.
id
,
form_data
.
content
)
memory
=
Memories
.
insert_new_memory
(
user
.
id
,
form_data
.
content
)
memory_embedding
=
request
.
app
.
state
.
EMBEDDING_FUNCTION
(
memory
.
content
)
collection
=
CHROMA_CLIENT
.
get_or_create_collection
(
name
=
f
"user-memory-
{
user
.
id
}
"
)
...
...
@@ -76,9 +74,8 @@ async def update_memory_by_id(
request
:
Request
,
form_data
:
MemoryUpdateModel
,
user
=
Depends
(
get_verified_user
),
db
=
Depends
(
get_db
),
):
memory
=
Memories
.
update_memory_by_id
(
db
,
memory_id
,
form_data
.
content
)
memory
=
Memories
.
update_memory_by_id
(
memory_id
,
form_data
.
content
)
if
memory
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"Memory not found"
)
...
...
@@ -129,12 +126,12 @@ async def query_memory(
############################
@
router
.
get
(
"/reset"
,
response_model
=
bool
)
async
def
reset_memory_from_vector_db
(
request
:
Request
,
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
request
:
Request
,
user
=
Depends
(
get_verified_user
)
):
CHROMA_CLIENT
.
delete_collection
(
f
"user-memory-
{
user
.
id
}
"
)
collection
=
CHROMA_CLIENT
.
get_or_create_collection
(
name
=
f
"user-memory-
{
user
.
id
}
"
)
memories
=
Memories
.
get_memories_by_user_id
(
db
,
user
.
id
)
memories
=
Memories
.
get_memories_by_user_id
(
user
.
id
)
for
memory
in
memories
:
memory_embedding
=
request
.
app
.
state
.
EMBEDDING_FUNCTION
(
memory
.
content
)
collection
.
upsert
(
...
...
@@ -151,8 +148,8 @@ async def reset_memory_from_vector_db(
@
router
.
delete
(
"/user"
,
response_model
=
bool
)
async
def
delete_memory_by_user_id
(
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
result
=
Memories
.
delete_memories_by_user_id
(
db
,
user
.
id
)
async
def
delete_memory_by_user_id
(
user
=
Depends
(
get_verified_user
)):
result
=
Memories
.
delete_memories_by_user_id
(
user
.
id
)
if
result
:
try
:
...
...
@@ -171,9 +168,9 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(g
@
router
.
delete
(
"/{memory_id}"
,
response_model
=
bool
)
async
def
delete_memory_by_id
(
memory_id
:
str
,
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
memory_id
:
str
,
user
=
Depends
(
get_verified_user
)
):
result
=
Memories
.
delete_memory_by_id_and_user_id
(
db
,
memory_id
,
user
.
id
)
result
=
Memories
.
delete_memory_by_id_and_user_id
(
memory_id
,
user
.
id
)
if
result
:
collection
=
CHROMA_CLIENT
.
get_or_create_collection
(
...
...
backend/apps/webui/routers/models.py
View file @
bee835cb
...
...
@@ -6,7 +6,6 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.models
import
Models
,
ModelModel
,
ModelForm
,
ModelResponse
from
utils.utils
import
get_verified_user
,
get_admin_user
...
...
@@ -20,8 +19,8 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
ModelResponse
])
async
def
get_models
(
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
return
Models
.
get_all_models
(
db
)
async
def
get_models
(
user
=
Depends
(
get_verified_user
)):
return
Models
.
get_all_models
()
############################
...
...
@@ -34,7 +33,6 @@ async def add_new_model(
request
:
Request
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
if
form_data
.
id
in
request
.
app
.
state
.
MODELS
:
raise
HTTPException
(
...
...
@@ -42,7 +40,7 @@ async def add_new_model(
detail
=
ERROR_MESSAGES
.
MODEL_ID_TAKEN
,
)
else
:
model
=
Models
.
insert_new_model
(
db
,
form_data
,
user
.
id
)
model
=
Models
.
insert_new_model
(
form_data
,
user
.
id
)
if
model
:
return
model
...
...
@@ -59,8 +57,8 @@ async def add_new_model(
@
router
.
get
(
"/{id}"
,
response_model
=
Optional
[
ModelModel
])
async
def
get_model_by_id
(
id
:
str
,
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
model
=
Models
.
get_model_by_id
(
db
,
id
)
async
def
get_model_by_id
(
id
:
str
,
user
=
Depends
(
get_verified_user
)):
model
=
Models
.
get_model_by_id
(
id
)
if
model
:
return
model
...
...
@@ -82,15 +80,14 @@ async def update_model_by_id(
id
:
str
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
model
=
Models
.
get_model_by_id
(
db
,
id
)
model
=
Models
.
get_model_by_id
(
id
)
if
model
:
model
=
Models
.
update_model_by_id
(
db
,
id
,
form_data
)
model
=
Models
.
update_model_by_id
(
id
,
form_data
)
return
model
else
:
if
form_data
.
id
in
request
.
app
.
state
.
MODELS
:
model
=
Models
.
insert_new_model
(
db
,
form_data
,
user
.
id
)
model
=
Models
.
insert_new_model
(
form_data
,
user
.
id
)
if
model
:
return
model
else
:
...
...
@@ -111,6 +108,6 @@ async def update_model_by_id(
@
router
.
delete
(
"/delete"
,
response_model
=
bool
)
async
def
delete_model_by_id
(
id
:
str
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
):
result
=
Models
.
delete_model_by_id
(
db
,
id
)
async
def
delete_model_by_id
(
id
:
str
,
user
=
Depends
(
get_admin_user
)):
result
=
Models
.
delete_model_by_id
(
id
)
return
result
backend/apps/webui/routers/prompts.py
View file @
bee835cb
...
...
@@ -6,7 +6,6 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.prompts
import
Prompts
,
PromptForm
,
PromptModel
from
utils.utils
import
get_current_user
,
get_admin_user
...
...
@@ -20,8 +19,8 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
PromptModel
])
async
def
get_prompts
(
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
):
return
Prompts
.
get_prompts
(
db
)
async
def
get_prompts
(
user
=
Depends
(
get_current_user
)):
return
Prompts
.
get_prompts
()
############################
...
...
@@ -31,11 +30,11 @@ async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
@
router
.
post
(
"/create"
,
response_model
=
Optional
[
PromptModel
])
async
def
create_new_prompt
(
form_data
:
PromptForm
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
form_data
:
PromptForm
,
user
=
Depends
(
get_admin_user
)
):
prompt
=
Prompts
.
get_prompt_by_command
(
db
,
form_data
.
command
)
prompt
=
Prompts
.
get_prompt_by_command
(
form_data
.
command
)
if
prompt
==
None
:
prompt
=
Prompts
.
insert_new_prompt
(
db
,
user
.
id
,
form_data
)
prompt
=
Prompts
.
insert_new_prompt
(
user
.
id
,
form_data
)
if
prompt
:
return
prompt
...
...
@@ -56,9 +55,9 @@ async def create_new_prompt(
@
router
.
get
(
"/command/{command}"
,
response_model
=
Optional
[
PromptModel
])
async
def
get_prompt_by_command
(
command
:
str
,
user
=
Depends
(
get_current_user
)
,
db
=
Depends
(
get_db
)
command
:
str
,
user
=
Depends
(
get_current_user
)
):
prompt
=
Prompts
.
get_prompt_by_command
(
db
,
f
"/
{
command
}
"
)
prompt
=
Prompts
.
get_prompt_by_command
(
f
"/
{
command
}
"
)
if
prompt
:
return
prompt
...
...
@@ -79,9 +78,8 @@ async def update_prompt_by_command(
command
:
str
,
form_data
:
PromptForm
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
prompt
=
Prompts
.
update_prompt_by_command
(
db
,
f
"/
{
command
}
"
,
form_data
)
prompt
=
Prompts
.
update_prompt_by_command
(
f
"/
{
command
}
"
,
form_data
)
if
prompt
:
return
prompt
else
:
...
...
@@ -98,7 +96,7 @@ async def update_prompt_by_command(
@
router
.
delete
(
"/command/{command}/delete"
,
response_model
=
bool
)
async
def
delete_prompt_by_command
(
command
:
str
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
command
:
str
,
user
=
Depends
(
get_admin_user
)
):
result
=
Prompts
.
delete_prompt_by_command
(
db
,
f
"/
{
command
}
"
)
result
=
Prompts
.
delete_prompt_by_command
(
f
"/
{
command
}
"
)
return
result
backend/apps/webui/routers/tools.py
View file @
bee835cb
...
...
@@ -6,7 +6,6 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.users
import
Users
from
apps.webui.models.tools
import
Tools
,
ToolForm
,
ToolModel
,
ToolResponse
from
apps.webui.utils
import
load_toolkit_module_by_id
...
...
@@ -34,7 +33,7 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
ToolResponse
])
async
def
get_toolkits
(
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
async
def
get_toolkits
(
user
=
Depends
(
get_verified_user
)):
toolkits
=
[
toolkit
for
toolkit
in
Tools
.
get_tools
()]
return
toolkits
...
...
@@ -45,8 +44,8 @@ async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
@
router
.
get
(
"/export"
,
response_model
=
List
[
ToolModel
])
async
def
get_toolkits
(
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
):
toolkits
=
[
toolkit
for
toolkit
in
Tools
.
get_tools
(
db
)]
async
def
get_toolkits
(
user
=
Depends
(
get_admin_user
)):
toolkits
=
[
toolkit
for
toolkit
in
Tools
.
get_tools
()]
return
toolkits
...
...
@@ -60,7 +59,6 @@ async def create_new_toolkit(
request
:
Request
,
form_data
:
ToolForm
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
if
not
form_data
.
id
.
isidentifier
():
raise
HTTPException
(
...
...
@@ -70,7 +68,7 @@ async def create_new_toolkit(
form_data
.
id
=
form_data
.
id
.
lower
()
toolkit
=
Tools
.
get_tool_by_id
(
db
,
form_data
.
id
)
toolkit
=
Tools
.
get_tool_by_id
(
form_data
.
id
)
if
toolkit
==
None
:
toolkit_path
=
os
.
path
.
join
(
TOOLS_DIR
,
f
"
{
form_data
.
id
}
.py"
)
try
:
...
...
@@ -84,7 +82,7 @@ async def create_new_toolkit(
TOOLS
[
form_data
.
id
]
=
toolkit_module
specs
=
get_tools_specs
(
TOOLS
[
form_data
.
id
])
toolkit
=
Tools
.
insert_new_tool
(
db
,
user
.
id
,
form_data
,
specs
)
toolkit
=
Tools
.
insert_new_tool
(
user
.
id
,
form_data
,
specs
)
tool_cache_dir
=
Path
(
CACHE_DIR
)
/
"tools"
/
form_data
.
id
tool_cache_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
...
...
@@ -115,8 +113,8 @@ async def create_new_toolkit(
@
router
.
get
(
"/id/{id}"
,
response_model
=
Optional
[
ToolModel
])
async
def
get_toolkit_by_id
(
id
:
str
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
):
toolkit
=
Tools
.
get_tool_by_id
(
db
,
id
)
async
def
get_toolkit_by_id
(
id
:
str
,
user
=
Depends
(
get_admin_user
)):
toolkit
=
Tools
.
get_tool_by_id
(
id
)
if
toolkit
:
return
toolkit
...
...
@@ -138,7 +136,6 @@ async def update_toolkit_by_id(
id
:
str
,
form_data
:
ToolForm
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
toolkit_path
=
os
.
path
.
join
(
TOOLS_DIR
,
f
"
{
id
}
.py"
)
...
...
@@ -160,7 +157,7 @@ async def update_toolkit_by_id(
}
print
(
updated
)
toolkit
=
Tools
.
update_tool_by_id
(
db
,
id
,
updated
)
toolkit
=
Tools
.
update_tool_by_id
(
id
,
updated
)
if
toolkit
:
return
toolkit
...
...
@@ -184,9 +181,9 @@ async def update_toolkit_by_id(
@
router
.
delete
(
"/id/{id}/delete"
,
response_model
=
bool
)
async
def
delete_toolkit_by_id
(
request
:
Request
,
id
:
str
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
request
:
Request
,
id
:
str
,
user
=
Depends
(
get_admin_user
)
):
result
=
Tools
.
delete_tool_by_id
(
db
,
id
)
result
=
Tools
.
delete_tool_by_id
(
id
)
if
result
:
TOOLS
=
request
.
app
.
state
.
TOOLS
...
...
backend/apps/webui/routers/users.py
View file @
bee835cb
...
...
@@ -9,7 +9,6 @@ import time
import
uuid
import
logging
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.users
import
(
UserModel
,
UserUpdateForm
,
...
...
@@ -42,9 +41,9 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
UserModel
])
async
def
get_users
(
skip
:
int
=
0
,
limit
:
int
=
50
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
skip
:
int
=
0
,
limit
:
int
=
50
,
user
=
Depends
(
get_admin_user
)
):
return
Users
.
get_users
(
db
,
skip
,
limit
)
return
Users
.
get_users
(
skip
,
limit
)
############################
...
...
@@ -72,11 +71,11 @@ async def update_user_permissions(
@
router
.
post
(
"/update/role"
,
response_model
=
Optional
[
UserModel
])
async
def
update_user_role
(
form_data
:
UserRoleUpdateForm
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
form_data
:
UserRoleUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
if
user
.
id
!=
form_data
.
id
and
form_data
.
id
!=
Users
.
get_first_user
(
db
).
id
:
return
Users
.
update_user_role_by_id
(
db
,
form_data
.
id
,
form_data
.
role
)
if
user
.
id
!=
form_data
.
id
and
form_data
.
id
!=
Users
.
get_first_user
().
id
:
return
Users
.
update_user_role_by_id
(
form_data
.
id
,
form_data
.
role
)
raise
HTTPException
(
status_code
=
status
.
HTTP_403_FORBIDDEN
,
...
...
@@ -91,9 +90,9 @@ async def update_user_role(
@
router
.
get
(
"/user/settings"
,
response_model
=
Optional
[
UserSettings
])
async
def
get_user_settings_by_session_user
(
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
user
=
Depends
(
get_verified_user
)
):
user
=
Users
.
get_user_by_id
(
db
,
user
.
id
)
user
=
Users
.
get_user_by_id
(
user
.
id
)
if
user
:
return
user
.
settings
else
:
...
...
@@ -110,9 +109,9 @@ async def get_user_settings_by_session_user(
@
router
.
post
(
"/user/settings/update"
,
response_model
=
UserSettings
)
async
def
update_user_settings_by_session_user
(
form_data
:
UserSettings
,
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
form_data
:
UserSettings
,
user
=
Depends
(
get_verified_user
)
):
user
=
Users
.
update_user_by_id
(
db
,
user
.
id
,
{
"settings"
:
form_data
.
model_dump
()})
user
=
Users
.
update_user_by_id
(
user
.
id
,
{
"settings"
:
form_data
.
model_dump
()})
if
user
:
return
user
.
settings
else
:
...
...
@@ -129,9 +128,9 @@ async def update_user_settings_by_session_user(
@
router
.
get
(
"/user/info"
,
response_model
=
Optional
[
dict
])
async
def
get_user_info_by_session_user
(
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
user
=
Depends
(
get_verified_user
)
):
user
=
Users
.
get_user_by_id
(
db
,
user
.
id
)
user
=
Users
.
get_user_by_id
(
user
.
id
)
if
user
:
return
user
.
info
else
:
...
...
@@ -148,15 +147,15 @@ async def get_user_info_by_session_user(
@
router
.
post
(
"/user/info/update"
,
response_model
=
Optional
[
dict
])
async
def
update_user_info_by_session_user
(
form_data
:
dict
,
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
form_data
:
dict
,
user
=
Depends
(
get_verified_user
)
):
user
=
Users
.
get_user_by_id
(
db
,
user
.
id
)
user
=
Users
.
get_user_by_id
(
user
.
id
)
if
user
:
if
user
.
info
is
None
:
user
.
info
=
{}
user
=
Users
.
update_user_by_id
(
db
,
user
.
id
,
{
"info"
:
{
**
user
.
info
,
**
form_data
}}
user
.
id
,
{
"info"
:
{
**
user
.
info
,
**
form_data
}}
)
if
user
:
return
user
.
info
...
...
@@ -184,14 +183,14 @@ class UserResponse(BaseModel):
@
router
.
get
(
"/{user_id}"
,
response_model
=
UserResponse
)
async
def
get_user_by_id
(
user_id
:
str
,
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
user_id
:
str
,
user
=
Depends
(
get_verified_user
)
):
# Check if user_id is a shared chat
# If it is, get the user_id from the chat
if
user_id
.
startswith
(
"shared-"
):
chat_id
=
user_id
.
replace
(
"shared-"
,
""
)
chat
=
Chats
.
get_chat_by_id
(
db
,
chat_id
)
chat
=
Chats
.
get_chat_by_id
(
chat_id
)
if
chat
:
user_id
=
chat
.
user_id
else
:
...
...
@@ -200,7 +199,7 @@ async def get_user_by_id(
detail
=
ERROR_MESSAGES
.
USER_NOT_FOUND
,
)
user
=
Users
.
get_user_by_id
(
db
,
user_id
)
user
=
Users
.
get_user_by_id
(
user_id
)
if
user
:
return
UserResponse
(
name
=
user
.
name
,
profile_image_url
=
user
.
profile_image_url
)
...
...
@@ -221,13 +220,12 @@ async def update_user_by_id(
user_id
:
str
,
form_data
:
UserUpdateForm
,
session_user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
user
=
Users
.
get_user_by_id
(
db
,
user_id
)
user
=
Users
.
get_user_by_id
(
user_id
)
if
user
:
if
form_data
.
email
.
lower
()
!=
user
.
email
:
email_user
=
Users
.
get_user_by_email
(
db
,
form_data
.
email
.
lower
())
email_user
=
Users
.
get_user_by_email
(
form_data
.
email
.
lower
())
if
email_user
:
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
...
...
@@ -237,11 +235,10 @@ async def update_user_by_id(
if
form_data
.
password
:
hashed
=
get_password_hash
(
form_data
.
password
)
log
.
debug
(
f
"hashed:
{
hashed
}
"
)
Auths
.
update_user_password_by_id
(
db
,
user_id
,
hashed
)
Auths
.
update_user_password_by_id
(
user_id
,
hashed
)
Auths
.
update_email_by_id
(
db
,
user_id
,
form_data
.
email
.
lower
())
Auths
.
update_email_by_id
(
user_id
,
form_data
.
email
.
lower
())
updated_user
=
Users
.
update_user_by_id
(
db
,
user_id
,
{
"name"
:
form_data
.
name
,
...
...
@@ -271,10 +268,10 @@ async def update_user_by_id(
@
router
.
delete
(
"/{user_id}"
,
response_model
=
bool
)
async
def
delete_user_by_id
(
user_id
:
str
,
user
=
Depends
(
get_admin_user
)
,
db
=
Depends
(
get_db
)
user_id
:
str
,
user
=
Depends
(
get_admin_user
)
):
if
user
.
id
!=
user_id
:
result
=
Auths
.
delete_auth_by_id
(
db
,
user_id
)
result
=
Auths
.
delete_auth_by_id
(
user_id
)
if
result
:
return
True
...
...
backend/main.py
View file @
bee835cb
...
...
@@ -57,7 +57,7 @@ from apps.webui.main import (
get_pipe_models
,
generate_function_chat_completion
,
)
from
apps.webui.internal.db
import
get_
db
,
SessionLocal
from
apps.webui.internal.db
import
get_
session
,
SessionLocal
from
pydantic
import
BaseModel
...
...
@@ -410,7 +410,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user
=
get_current_user
(
request
,
get_http_authorization_cred
(
request
.
headers
.
get
(
"Authorization"
)),
SessionLocal
(),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files
=
False
...
...
@@ -800,9 +799,7 @@ app.add_middleware(
@
app
.
middleware
(
"http"
)
async
def
check_url
(
request
:
Request
,
call_next
):
if
len
(
app
.
state
.
MODELS
)
==
0
:
db
=
SessionLocal
()
await
get_all_models
(
db
)
db
.
commit
()
await
get_all_models
()
else
:
pass
...
...
@@ -836,12 +833,12 @@ app.mount("/api/v1", webui_app)
webui_app
.
state
.
EMBEDDING_FUNCTION
=
rag_app
.
state
.
EMBEDDING_FUNCTION
async
def
get_all_models
(
db
:
Session
):
async
def
get_all_models
():
pipe_models
=
[]
openai_models
=
[]
ollama_models
=
[]
pipe_models
=
await
get_pipe_models
(
db
)
pipe_models
=
await
get_pipe_models
()
if
app
.
state
.
config
.
ENABLE_OPENAI_API
:
openai_models
=
await
get_openai_models
()
...
...
@@ -863,7 +860,7 @@ async def get_all_models(db: Session):
models
=
pipe_models
+
openai_models
+
ollama_models
custom_models
=
Models
.
get_all_models
(
db
)
custom_models
=
Models
.
get_all_models
()
for
custom_model
in
custom_models
:
if
custom_model
.
base_model_id
==
None
:
for
model
in
models
:
...
...
@@ -903,8 +900,8 @@ async def get_all_models(db: Session):
@
app
.
get
(
"/api/models"
)
async
def
get_models
(
user
=
Depends
(
get_verified_user
)
,
db
=
Depends
(
get_db
)
):
models
=
await
get_all_models
(
db
)
async
def
get_models
(
user
=
Depends
(
get_verified_user
)):
models
=
await
get_all_models
()
# Filter out filter pipelines
models
=
[
...
...
@@ -1608,9 +1605,8 @@ async def get_pipeline_valves(
urlIdx
:
Optional
[
int
],
pipeline_id
:
str
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
models
=
await
get_all_models
(
db
)
models
=
await
get_all_models
()
r
=
None
try
:
...
...
@@ -1649,9 +1645,8 @@ async def get_pipeline_valves_spec(
urlIdx
:
Optional
[
int
],
pipeline_id
:
str
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
models
=
await
get_all_models
(
db
)
models
=
await
get_all_models
()
r
=
None
try
:
...
...
@@ -1690,9 +1685,8 @@ async def update_pipeline_valves(
pipeline_id
:
str
,
form_data
:
dict
,
user
=
Depends
(
get_admin_user
),
db
=
Depends
(
get_db
),
):
models
=
await
get_all_models
(
db
)
models
=
await
get_all_models
()
r
=
None
try
:
...
...
@@ -2040,7 +2034,8 @@ async def healthcheck():
@
app
.
get
(
"/health/db"
)
async
def
healthcheck_with_db
(
db
:
Session
=
Depends
(
get_db
)):
async
def
healthcheck_with_db
():
with
get_session
()
as
db
:
result
=
db
.
execute
(
text
(
"SELECT 1;"
)).
all
()
return
{
"status"
:
True
}
...
...
backend/migrations/versions/22b5ab2667b8_init.py
deleted
100644 → 0
View file @
df09d083
"""init
Revision ID: 22b5ab2667b8
Revises:
Create Date: 2024-06-20 13:22:40.397002
"""
from
typing
import
Sequence
,
Union
from
alembic
import
op
import
sqlalchemy
as
sa
from
sqlalchemy.engine.reflection
import
Inspector
import
apps.webui.internal.db
# revision identifiers, used by Alembic.
revision
:
str
=
"22b5ab2667b8"
down_revision
:
Union
[
str
,
None
]
=
None
branch_labels
:
Union
[
str
,
Sequence
[
str
],
None
]
=
None
depends_on
:
Union
[
str
,
Sequence
[
str
],
None
]
=
None
def
upgrade
()
->
None
:
con
=
op
.
get_bind
()
inspector
=
Inspector
.
from_engine
(
con
)
tables
=
set
(
inspector
.
get_table_names
())
# ### commands auto generated by Alembic - please adjust! ###
if
not
"auth"
in
tables
:
op
.
create_table
(
"auth"
,
sa
.
Column
(
"id"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"email"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"password"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"active"
,
sa
.
Boolean
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"id"
),
)
if
not
"chat"
in
tables
:
op
.
create_table
(
"chat"
,
sa
.
Column
(
"id"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"user_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"title"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"chat"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"created_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
"updated_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
"share_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"archived"
,
sa
.
Boolean
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"id"
),
sa
.
UniqueConstraint
(
"share_id"
),
)
if
not
"chatidtag"
in
tables
:
op
.
create_table
(
"chatidtag"
,
sa
.
Column
(
"id"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"tag_name"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"chat_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"user_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"timestamp"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"id"
),
)
if
not
"document"
in
tables
:
op
.
create_table
(
"document"
,
sa
.
Column
(
"collection_name"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"name"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"title"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"filename"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"content"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"user_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"timestamp"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"collection_name"
),
sa
.
UniqueConstraint
(
"name"
),
)
if
not
"memory"
in
tables
:
op
.
create_table
(
"memory"
,
sa
.
Column
(
"id"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"user_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"content"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"updated_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
"created_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"id"
),
)
if
not
"model"
in
tables
:
op
.
create_table
(
"model"
,
sa
.
Column
(
"id"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"user_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"base_model_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"name"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"params"
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
"meta"
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
"updated_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
"created_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"id"
),
)
if
not
"prompt"
in
tables
:
op
.
create_table
(
"prompt"
,
sa
.
Column
(
"command"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"user_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"title"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"content"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"timestamp"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"command"
),
)
if
not
"tag"
in
tables
:
op
.
create_table
(
"tag"
,
sa
.
Column
(
"id"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"name"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"user_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"data"
,
sa
.
String
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"id"
),
)
if
not
"tool"
in
tables
:
op
.
create_table
(
"tool"
,
sa
.
Column
(
"id"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"user_id"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"name"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"content"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"specs"
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
"meta"
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
"updated_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
"created_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"id"
),
)
if
not
"user"
in
tables
:
op
.
create_table
(
"user"
,
sa
.
Column
(
"id"
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
"name"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"email"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"role"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"profile_image_url"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"last_active_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
"updated_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
"created_at"
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
"api_key"
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
"settings"
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
"info"
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
"id"
),
sa
.
UniqueConstraint
(
"api_key"
),
)
if
not
"file"
in
tables
:
op
.
create_table
(
'file'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'filename'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'meta'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'created_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
if
not
"function"
in
tables
:
op
.
create_table
(
'function'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'name'
,
sa
.
Text
(),
nullable
=
True
),
sa
.
Column
(
'type'
,
sa
.
Text
(),
nullable
=
True
),
sa
.
Column
(
'content'
,
sa
.
Text
(),
nullable
=
True
),
sa
.
Column
(
'meta'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'updated_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'created_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
# ### end Alembic commands ###
def
downgrade
()
->
None
:
# ### commands auto generated by Alembic - please adjust! ###
# do nothing as we assume we had previous migrations from peewee-migrate
pass
# ### end Alembic commands ###
backend/migrations/versions/ba76b0bae648_init.py
0 → 100644
View file @
bee835cb
"""init
Revision ID: ba76b0bae648
Revises:
Create Date: 2024-06-24 09:09:11.636336
"""
from
typing
import
Sequence
,
Union
from
alembic
import
op
import
sqlalchemy
as
sa
import
apps.webui.internal.db
# revision identifiers, used by Alembic.
revision
:
str
=
'ba76b0bae648'
down_revision
:
Union
[
str
,
None
]
=
None
branch_labels
:
Union
[
str
,
Sequence
[
str
],
None
]
=
None
depends_on
:
Union
[
str
,
Sequence
[
str
],
None
]
=
None
def
upgrade
()
->
None
:
# ### commands auto generated by Alembic - please adjust! ###
op
.
create_table
(
'auth'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'email'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'password'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'active'
,
sa
.
Boolean
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
op
.
create_table
(
'chat'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'title'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'chat'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'created_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'updated_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'share_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'archived'
,
sa
.
Boolean
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
),
sa
.
UniqueConstraint
(
'share_id'
)
)
op
.
create_table
(
'chatidtag'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'tag_name'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'chat_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'timestamp'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
op
.
create_table
(
'document'
,
sa
.
Column
(
'collection_name'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'name'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'title'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'filename'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'content'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'timestamp'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'collection_name'
),
sa
.
UniqueConstraint
(
'name'
)
)
op
.
create_table
(
'file'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'filename'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'meta'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'created_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
op
.
create_table
(
'function'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'name'
,
sa
.
Text
(),
nullable
=
True
),
sa
.
Column
(
'type'
,
sa
.
Text
(),
nullable
=
True
),
sa
.
Column
(
'content'
,
sa
.
Text
(),
nullable
=
True
),
sa
.
Column
(
'meta'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'valves'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'is_active'
,
sa
.
Boolean
(),
nullable
=
True
),
sa
.
Column
(
'updated_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'created_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
op
.
create_table
(
'memory'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'content'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'updated_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'created_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
op
.
create_table
(
'model'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'base_model_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'name'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'params'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'meta'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'updated_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'created_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
op
.
create_table
(
'prompt'
,
sa
.
Column
(
'command'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'title'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'content'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'timestamp'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'command'
)
)
op
.
create_table
(
'tag'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'name'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'data'
,
sa
.
String
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
op
.
create_table
(
'tool'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'user_id'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'name'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'content'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'specs'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'meta'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'valves'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'updated_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'created_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
)
)
op
.
create_table
(
'user'
,
sa
.
Column
(
'id'
,
sa
.
String
(),
nullable
=
False
),
sa
.
Column
(
'name'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'email'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'role'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'profile_image_url'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'last_active_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'updated_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'created_at'
,
sa
.
BigInteger
(),
nullable
=
True
),
sa
.
Column
(
'api_key'
,
sa
.
String
(),
nullable
=
True
),
sa
.
Column
(
'settings'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
Column
(
'info'
,
apps
.
webui
.
internal
.
db
.
JSONField
(),
nullable
=
True
),
sa
.
PrimaryKeyConstraint
(
'id'
),
sa
.
UniqueConstraint
(
'api_key'
)
)
# ### end Alembic commands ###
def
downgrade
()
->
None
:
# ### commands auto generated by Alembic - please adjust! ###
op
.
drop_table
(
'user'
)
op
.
drop_table
(
'tool'
)
op
.
drop_table
(
'tag'
)
op
.
drop_table
(
'prompt'
)
op
.
drop_table
(
'model'
)
op
.
drop_table
(
'memory'
)
op
.
drop_table
(
'function'
)
op
.
drop_table
(
'file'
)
op
.
drop_table
(
'document'
)
op
.
drop_table
(
'chatidtag'
)
op
.
drop_table
(
'chat'
)
op
.
drop_table
(
'auth'
)
# ### end Alembic commands ###
backend/test/apps/webui/routers/test_auths.py
View file @
bee835cb
...
...
@@ -31,7 +31,6 @@ class TestAuths(AbstractPostgresTest):
from
utils.utils
import
get_password_hash
user
=
self
.
auths
.
insert_new_auth
(
self
.
db_session
,
email
=
"john.doe@openwebui.com"
,
password
=
get_password_hash
(
"old_password"
),
name
=
"John Doe"
,
...
...
@@ -45,7 +44,7 @@ class TestAuths(AbstractPostgresTest):
json
=
{
"name"
:
"John Doe 2"
,
"profile_image_url"
:
"/user2.png"
},
)
assert
response
.
status_code
==
200
db_user
=
self
.
users
.
get_user_by_id
(
self
.
db_session
,
user
.
id
)
db_user
=
self
.
users
.
get_user_by_id
(
user
.
id
)
assert
db_user
.
name
==
"John Doe 2"
assert
db_user
.
profile_image_url
==
"/user2.png"
...
...
@@ -53,7 +52,6 @@ class TestAuths(AbstractPostgresTest):
from
utils.utils
import
get_password_hash
user
=
self
.
auths
.
insert_new_auth
(
self
.
db_session
,
email
=
"john.doe@openwebui.com"
,
password
=
get_password_hash
(
"old_password"
),
name
=
"John Doe"
,
...
...
@@ -69,11 +67,11 @@ class TestAuths(AbstractPostgresTest):
assert
response
.
status_code
==
200
old_auth
=
self
.
auths
.
authenticate_user
(
self
.
db_session
,
"john.doe@openwebui.com"
,
"old_password"
"john.doe@openwebui.com"
,
"old_password"
)
assert
old_auth
is
None
new_auth
=
self
.
auths
.
authenticate_user
(
self
.
db_session
,
"john.doe@openwebui.com"
,
"new_password"
"john.doe@openwebui.com"
,
"new_password"
)
assert
new_auth
is
not
None
...
...
@@ -81,7 +79,6 @@ class TestAuths(AbstractPostgresTest):
from
utils.utils
import
get_password_hash
user
=
self
.
auths
.
insert_new_auth
(
self
.
db_session
,
email
=
"john.doe@openwebui.com"
,
password
=
get_password_hash
(
"password"
),
name
=
"John Doe"
,
...
...
@@ -144,7 +141,6 @@ class TestAuths(AbstractPostgresTest):
def
test_get_admin_details
(
self
):
self
.
auths
.
insert_new_auth
(
self
.
db_session
,
email
=
"john.doe@openwebui.com"
,
password
=
"password"
,
name
=
"John Doe"
,
...
...
@@ -162,7 +158,6 @@ class TestAuths(AbstractPostgresTest):
def
test_create_api_key_
(
self
):
user
=
self
.
auths
.
insert_new_auth
(
self
.
db_session
,
email
=
"john.doe@openwebui.com"
,
password
=
"password"
,
name
=
"John Doe"
,
...
...
@@ -178,31 +173,29 @@ class TestAuths(AbstractPostgresTest):
def
test_delete_api_key
(
self
):
user
=
self
.
auths
.
insert_new_auth
(
self
.
db_session
,
email
=
"john.doe@openwebui.com"
,
password
=
"password"
,
name
=
"John Doe"
,
profile_image_url
=
"/user.png"
,
role
=
"admin"
,
)
self
.
users
.
update_user_api_key_by_id
(
self
.
db_session
,
user
.
id
,
"abc"
)
self
.
users
.
update_user_api_key_by_id
(
user
.
id
,
"abc"
)
with
mock_webui_user
(
id
=
user
.
id
):
response
=
self
.
fast_api_client
.
delete
(
self
.
create_url
(
"/api_key"
))
assert
response
.
status_code
==
200
assert
response
.
json
()
==
True
db_user
=
self
.
users
.
get_user_by_id
(
self
.
db_session
,
user
.
id
)
db_user
=
self
.
users
.
get_user_by_id
(
user
.
id
)
assert
db_user
.
api_key
is
None
def
test_get_api_key
(
self
):
user
=
self
.
auths
.
insert_new_auth
(
self
.
db_session
,
email
=
"john.doe@openwebui.com"
,
password
=
"password"
,
name
=
"John Doe"
,
profile_image_url
=
"/user.png"
,
role
=
"admin"
,
)
self
.
users
.
update_user_api_key_by_id
(
self
.
db_session
,
user
.
id
,
"abc"
)
self
.
users
.
update_user_api_key_by_id
(
user
.
id
,
"abc"
)
with
mock_webui_user
(
id
=
user
.
id
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/api_key"
))
assert
response
.
status_code
==
200
...
...
backend/test/apps/webui/routers/test_chats.py
View file @
bee835cb
...
...
@@ -18,7 +18,6 @@ class TestChats(AbstractPostgresTest):
self
.
chats
=
Chats
self
.
chats
.
insert_new_chat
(
self
.
db_session
,
"2"
,
ChatForm
(
**
{
...
...
@@ -46,7 +45,7 @@ class TestChats(AbstractPostgresTest):
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
delete
(
self
.
create_url
(
"/"
))
assert
response
.
status_code
==
200
assert
len
(
self
.
chats
.
get_chats
(
self
.
db_session
))
==
0
assert
len
(
self
.
chats
.
get_chats
())
==
0
def
test_get_user_chat_list_by_user_id
(
self
):
with
mock_webui_user
(
id
=
"3"
):
...
...
@@ -84,14 +83,13 @@ class TestChats(AbstractPostgresTest):
assert
data
[
"title"
]
==
"New Chat"
assert
data
[
"updated_at"
]
is
not
None
assert
data
[
"created_at"
]
is
not
None
assert
len
(
self
.
chats
.
get_chats
(
self
.
db_session
))
==
2
assert
len
(
self
.
chats
.
get_chats
())
==
2
def
test_get_user_chats
(
self
):
self
.
test_get_session_user_chat_list
()
def
test_get_user_archived_chats
(
self
):
self
.
chats
.
archive_all_chats_by_user_id
(
self
.
db_session
,
"2"
)
self
.
db_session
.
commit
()
self
.
chats
.
archive_all_chats_by_user_id
(
"2"
)
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/all/archived"
))
assert
response
.
status_code
==
200
...
...
@@ -114,12 +112,11 @@ class TestChats(AbstractPostgresTest):
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
"/archive/all"
))
assert
response
.
status_code
==
200
assert
len
(
self
.
chats
.
get_archived_chats_by_user_id
(
self
.
db_session
,
"2"
))
==
1
assert
len
(
self
.
chats
.
get_archived_chats_by_user_id
(
"2"
))
==
1
def
test_get_shared_chat_by_id
(
self
):
chat_id
=
self
.
chats
.
get_chats
(
self
.
db_session
)[
0
].
id
self
.
chats
.
update_chat_share_id_by_id
(
self
.
db_session
,
chat_id
,
chat_id
)
self
.
db_session
.
commit
()
chat_id
=
self
.
chats
.
get_chats
()[
0
].
id
self
.
chats
.
update_chat_share_id_by_id
(
chat_id
,
chat_id
)
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
f
"/share/
{
chat_id
}
"
))
assert
response
.
status_code
==
200
...
...
@@ -136,7 +133,7 @@ class TestChats(AbstractPostgresTest):
assert
data
[
"title"
]
==
"New Chat"
def
test_get_chat_by_id
(
self
):
chat_id
=
self
.
chats
.
get_chats
(
self
.
db_session
)[
0
].
id
chat_id
=
self
.
chats
.
get_chats
()[
0
].
id
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
f
"/
{
chat_id
}
"
))
assert
response
.
status_code
==
200
...
...
@@ -153,7 +150,7 @@ class TestChats(AbstractPostgresTest):
assert
data
[
"user_id"
]
==
"2"
def
test_update_chat_by_id
(
self
):
chat_id
=
self
.
chats
.
get_chats
(
self
.
db_session
)[
0
].
id
chat_id
=
self
.
chats
.
get_chats
()[
0
].
id
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
f
"/
{
chat_id
}
"
),
...
...
@@ -181,14 +178,14 @@ class TestChats(AbstractPostgresTest):
assert
data
[
"user_id"
]
==
"2"
def
test_delete_chat_by_id
(
self
):
chat_id
=
self
.
chats
.
get_chats
(
self
.
db_session
)[
0
].
id
chat_id
=
self
.
chats
.
get_chats
()[
0
].
id
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
delete
(
self
.
create_url
(
f
"/
{
chat_id
}
"
))
assert
response
.
status_code
==
200
assert
response
.
json
()
is
True
def
test_clone_chat_by_id
(
self
):
chat_id
=
self
.
chats
.
get_chats
(
self
.
db_session
)[
0
].
id
chat_id
=
self
.
chats
.
get_chats
()[
0
].
id
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
f
"/
{
chat_id
}
/clone"
))
...
...
@@ -209,31 +206,30 @@ class TestChats(AbstractPostgresTest):
assert
data
[
"user_id"
]
==
"2"
def
test_archive_chat_by_id
(
self
):
chat_id
=
self
.
chats
.
get_chats
(
self
.
db_session
)[
0
].
id
chat_id
=
self
.
chats
.
get_chats
()[
0
].
id
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
f
"/
{
chat_id
}
/archive"
))
assert
response
.
status_code
==
200
chat
=
self
.
chats
.
get_chat_by_id
(
self
.
db_session
,
chat_id
)
chat
=
self
.
chats
.
get_chat_by_id
(
chat_id
)
assert
chat
.
archived
is
True
def
test_share_chat_by_id
(
self
):
chat_id
=
self
.
chats
.
get_chats
(
self
.
db_session
)[
0
].
id
chat_id
=
self
.
chats
.
get_chats
()[
0
].
id
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
f
"/
{
chat_id
}
/share"
))
assert
response
.
status_code
==
200
chat
=
self
.
chats
.
get_chat_by_id
(
self
.
db_session
,
chat_id
)
chat
=
self
.
chats
.
get_chat_by_id
(
chat_id
)
assert
chat
.
share_id
is
not
None
def
test_delete_shared_chat_by_id
(
self
):
chat_id
=
self
.
chats
.
get_chats
(
self
.
db_session
)[
0
].
id
chat_id
=
self
.
chats
.
get_chats
()[
0
].
id
share_id
=
str
(
uuid
.
uuid4
())
self
.
chats
.
update_chat_share_id_by_id
(
self
.
db_session
,
chat_id
,
share_id
)
self
.
db_session
.
commit
()
self
.
chats
.
update_chat_share_id_by_id
(
chat_id
,
share_id
)
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
delete
(
self
.
create_url
(
f
"/
{
chat_id
}
/share"
))
assert
response
.
status_code
chat
=
self
.
chats
.
get_chat_by_id
(
self
.
db_session
,
chat_id
)
chat
=
self
.
chats
.
get_chat_by_id
(
chat_id
)
assert
chat
.
share_id
is
None
backend/test/apps/webui/routers/test_documents.py
View file @
bee835cb
...
...
@@ -14,7 +14,7 @@ class TestDocuments(AbstractPostgresTest):
def
test_documents
(
self
):
# Empty database
assert
len
(
self
.
documents
.
get_docs
(
self
.
db_session
))
==
0
assert
len
(
self
.
documents
.
get_docs
())
==
0
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/"
))
assert
response
.
status_code
==
200
...
...
@@ -34,7 +34,7 @@ class TestDocuments(AbstractPostgresTest):
)
assert
response
.
status_code
==
200
assert
response
.
json
()[
"name"
]
==
"doc_name"
assert
len
(
self
.
documents
.
get_docs
(
self
.
db_session
))
==
1
assert
len
(
self
.
documents
.
get_docs
())
==
1
# Get the document
with
mock_webui_user
(
id
=
"2"
):
...
...
@@ -61,7 +61,7 @@ class TestDocuments(AbstractPostgresTest):
)
assert
response
.
status_code
==
200
assert
response
.
json
()[
"name"
]
==
"doc_name 2"
assert
len
(
self
.
documents
.
get_docs
(
self
.
db_session
))
==
2
assert
len
(
self
.
documents
.
get_docs
())
==
2
# Get all documents
with
mock_webui_user
(
id
=
"2"
):
...
...
@@ -95,7 +95,7 @@ class TestDocuments(AbstractPostgresTest):
assert
data
[
"content"
]
==
{
"tags"
:
[{
"name"
:
"testing-tag"
},
{
"name"
:
"another-tag"
}]
}
assert
len
(
self
.
documents
.
get_docs
(
self
.
db_session
))
==
2
assert
len
(
self
.
documents
.
get_docs
())
==
2
# Delete the first document
with
mock_webui_user
(
id
=
"2"
):
...
...
@@ -103,4 +103,4 @@ class TestDocuments(AbstractPostgresTest):
self
.
create_url
(
"/doc/delete?name=doc_name rework"
)
)
assert
response
.
status_code
==
200
assert
len
(
self
.
documents
.
get_docs
(
self
.
db_session
))
==
1
assert
len
(
self
.
documents
.
get_docs
())
==
1
backend/test/apps/webui/routers/test_prompts.py
View file @
bee835cb
...
...
@@ -68,6 +68,16 @@ class TestPrompts(AbstractPostgresTest):
assert
data
[
"content"
]
==
"description Updated"
assert
data
[
"user_id"
]
==
"3"
# Get prompt by command
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/command/my-command2"
))
assert
response
.
status_code
==
200
data
=
response
.
json
()
assert
data
[
"command"
]
==
"/my-command2"
assert
data
[
"title"
]
==
"Hello World Updated"
assert
data
[
"content"
]
==
"description Updated"
assert
data
[
"user_id"
]
==
"3"
# Delete prompt
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
delete
(
...
...
backend/test/apps/webui/routers/test_users.py
View file @
bee835cb
...
...
@@ -33,7 +33,6 @@ class TestUsers(AbstractPostgresTest):
def
setup_method
(
self
):
super
().
setup_method
()
self
.
users
.
insert_new_user
(
self
.
db_session
,
id
=
"1"
,
name
=
"user 1"
,
email
=
"user1@openwebui.com"
,
...
...
@@ -41,7 +40,6 @@ class TestUsers(AbstractPostgresTest):
role
=
"user"
,
)
self
.
users
.
insert_new_user
(
self
.
db_session
,
id
=
"2"
,
name
=
"user 2"
,
email
=
"user2@openwebui.com"
,
...
...
backend/utils/utils.py
View file @
bee835cb
...
...
@@ -2,7 +2,6 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from
fastapi
import
HTTPException
,
status
,
Depends
,
Request
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.users
import
Users
from
pydantic
import
BaseModel
...
...
@@ -79,7 +78,6 @@ def get_http_authorization_cred(auth_header: str):
def
get_current_user
(
request
:
Request
,
auth_token
:
HTTPAuthorizationCredentials
=
Depends
(
bearer_security
),
db
=
Depends
(
get_db
),
):
token
=
None
...
...
@@ -94,19 +92,19 @@ def get_current_user(
# auth by api key
if
token
.
startswith
(
"sk-"
):
return
get_current_user_by_api_key
(
db
,
token
)
return
get_current_user_by_api_key
(
token
)
# auth by jwt token
data
=
decode_token
(
token
)
if
data
!=
None
and
"id"
in
data
:
user
=
Users
.
get_user_by_id
(
db
,
data
[
"id"
])
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
if
user
is
None
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
INVALID_TOKEN
,
)
else
:
Users
.
update_user_last_active_by_id
(
db
,
user
.
id
)
Users
.
update_user_last_active_by_id
(
user
.
id
)
return
user
else
:
raise
HTTPException
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment