Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
9bcd4ce5
Unverified
Commit
9bcd4ce5
authored
Jul 09, 2024
by
Timothy Jaeryang Baek
Committed by
GitHub
Jul 09, 2024
Browse files
Merge pull request #3559 from open-webui/dev
0.3.8
parents
824966ad
b38abf23
Changes
178
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1219 additions
and
844 deletions
+1219
-844
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+53
-42
backend/apps/webui/models/prompts.py
backend/apps/webui/models/prompts.py
+47
-46
backend/apps/webui/models/tags.py
backend/apps/webui/models/tags.py
+141
-106
backend/apps/webui/models/tools.py
backend/apps/webui/models/tools.py
+75
-66
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+118
-94
backend/apps/webui/routers/chats.py
backend/apps/webui/routers/chats.py
+5
-3
backend/apps/webui/routers/documents.py
backend/apps/webui/routers/documents.py
+3
-1
backend/apps/webui/routers/files.py
backend/apps/webui/routers/files.py
+1
-4
backend/apps/webui/routers/functions.py
backend/apps/webui/routers/functions.py
+4
-1
backend/apps/webui/routers/memories.py
backend/apps/webui/routers/memories.py
+3
-1
backend/apps/webui/routers/models.py
backend/apps/webui/routers/models.py
+8
-2
backend/apps/webui/routers/prompts.py
backend/apps/webui/routers/prompts.py
+3
-1
backend/apps/webui/routers/tools.py
backend/apps/webui/routers/tools.py
+7
-3
backend/apps/webui/routers/users.py
backend/apps/webui/routers/users.py
+4
-2
backend/apps/webui/routers/utils.py
backend/apps/webui/routers/utils.py
+4
-4
backend/config.py
backend/config.py
+67
-10
backend/constants.py
backend/constants.py
+11
-0
backend/main.py
backend/main.py
+565
-458
backend/migrations/README
backend/migrations/README
+4
-0
backend/migrations/env.py
backend/migrations/env.py
+96
-0
No files found.
backend/apps/webui/models/models.py
View file @
9bcd4ce5
...
...
@@ -2,13 +2,10 @@ import json
import
logging
from
typing
import
Optional
import
peewee
as
pw
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
,
get_db
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
...
...
@@ -32,7 +29,7 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table
class
ModelMeta
(
BaseModel
):
profile_image_url
:
Optional
[
str
]
=
"/favicon.png"
profile_image_url
:
Optional
[
str
]
=
"/
static/
favicon.png"
description
:
Optional
[
str
]
=
None
"""
...
...
@@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
pass
class
Model
(
pw
.
Model
):
id
=
pw
.
TextField
(
unique
=
True
)
class
Model
(
Base
):
__tablename__
=
"model"
id
=
Column
(
Text
,
primary_key
=
True
)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id
=
pw
.
TextField
(
)
user_id
=
Column
(
Text
)
base_model_id
=
pw
.
TextField
(
null
=
True
)
base_model_id
=
Column
(
Text
,
nullable
=
True
)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name
=
pw
.
TextField
(
)
name
=
Column
(
Text
)
"""
The human-readable display name of the model.
"""
params
=
JSONField
(
)
params
=
Column
(
JSONField
)
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta
=
JSONField
(
)
meta
=
Column
(
JSONField
)
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Meta
:
database
=
DB
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
ModelModel
(
BaseModel
):
...
...
@@ -92,6 +88,8 @@ class ModelModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -115,12 +113,6 @@ class ModelForm(BaseModel):
class
ModelsTable
:
def
__init__
(
self
,
db
:
pw
.
SqliteDatabase
|
pw
.
PostgresqlDatabase
,
):
self
.
db
=
db
self
.
db
.
create_tables
([
Model
])
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
...
...
@@ -134,34 +126,50 @@ class ModelsTable:
}
)
try
:
result
=
Model
.
create
(
**
model
.
model_dump
())
if
result
:
return
model
else
:
return
None
with
get_db
()
as
db
:
result
=
Model
(
**
model
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
ModelModel
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
print
(
e
)
return
None
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
return
[
ModelModel
(
**
model_to_dict
(
model
))
for
model
in
Model
.
select
()]
with
get_db
()
as
db
:
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_to_dict
(
model
))
with
get_db
()
as
db
:
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
.
model_validate
(
model
)
except
:
return
None
def
update_model_by_id
(
self
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
try
:
# update only the fields that are present in the model
query
=
Model
.
update
(
**
model
.
model_dump
()).
where
(
Model
.
id
==
id
)
query
.
execute
()
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_to_dict
(
model
))
with
get_db
()
as
db
:
# update only the fields that are present in the model
result
=
(
db
.
query
(
Model
)
.
filter_by
(
id
=
id
)
.
update
(
model
.
model_dump
(
exclude
=
{
"id"
},
exclude_none
=
True
))
)
db
.
commit
()
model
=
db
.
get
(
Model
,
id
)
db
.
refresh
(
model
)
return
ModelModel
.
model_validate
(
model
)
except
Exception
as
e
:
print
(
e
)
...
...
@@ -169,11 +177,14 @@ class ModelsTable:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Model
.
delete
().
where
(
Model
.
id
==
id
)
query
.
execute
()
return
True
with
get_db
()
as
db
:
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
except
:
return
False
Models
=
ModelsTable
(
DB
)
Models
=
ModelsTable
()
backend/apps/webui/models/prompts.py
View file @
9bcd4ce5
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
from
utils.utils
import
decode_token
from
utils.misc
import
get_gravatar_url
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
D
B
from
apps.webui.internal.db
import
B
ase
,
get_db
import
json
...
...
@@ -16,15 +13,14 @@ import json
####################
class
Prompt
(
Model
):
command
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
title
=
TextField
()
content
=
TextField
()
timestamp
=
BigIntegerField
()
class
Prompt
(
Base
):
__tablename__
=
"prompt"
class
Meta
:
database
=
DB
command
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
title
=
Column
(
Text
)
content
=
Column
(
Text
)
timestamp
=
Column
(
BigInteger
)
class
PromptModel
(
BaseModel
):
...
...
@@ -34,6 +30,8 @@ class PromptModel(BaseModel):
content
:
str
timestamp
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -48,10 +46,6 @@ class PromptForm(BaseModel):
class
PromptsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Prompt
])
def
insert_new_prompt
(
self
,
user_id
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
...
...
@@ -66,53 +60,60 @@ class PromptsTable:
)
try
:
result
=
Prompt
.
create
(
**
prompt
.
model_dump
())
if
result
:
return
prompt
else
:
return
None
except
:
with
get_db
()
as
db
:
result
=
Prompt
(
**
prompt
.
dict
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
PromptModel
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
return
None
def
get_prompt_by_command
(
self
,
command
:
str
)
->
Optional
[
PromptModel
]:
try
:
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
return
PromptModel
(
**
model_to_dict
(
prompt
))
with
get_db
()
as
db
:
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
return
PromptModel
.
model_validate
(
prompt
)
except
:
return
None
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
return
[
PromptModel
(
**
model_to_dict
(
prompt
))
for
prompt
in
Prompt
.
select
()
# .limit(limit).offset(skip
)
]
with
get_db
()
as
db
:
return
[
PromptModel
.
model_validate
(
prompt
)
for
prompt
in
db
.
query
(
Prompt
).
all
(
)
]
def
update_prompt_by_command
(
self
,
command
:
str
,
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
try
:
query
=
Prompt
.
update
(
title
=
form_data
.
title
,
content
=
form_data
.
content
,
timestamp
=
int
(
time
.
time
()),
).
where
(
Prompt
.
command
==
command
)
query
.
execute
()
prompt
=
Prompt
.
get
(
Prompt
.
command
==
command
)
return
PromptModel
(
**
model_to_dict
(
prompt
))
with
get_db
()
as
db
:
prompt
=
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
first
()
prompt
.
title
=
form_data
.
title
prompt
.
content
=
form_data
.
content
prompt
.
timestamp
=
int
(
time
.
time
())
db
.
commit
()
return
PromptModel
.
model_validate
(
prompt
)
except
:
return
None
def
delete_prompt_by_command
(
self
,
command
:
str
)
->
bool
:
try
:
query
=
Prompt
.
delete
().
where
((
Prompt
.
command
==
command
))
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
db
.
query
(
Prompt
).
filter_by
(
command
=
command
).
delete
()
db
.
commit
()
return
True
return
True
except
:
return
False
Prompts
=
PromptsTable
(
DB
)
Prompts
=
PromptsTable
()
backend/apps/webui/models/tags.py
View file @
9bcd4ce5
from
pydantic
import
BaseModel
from
typing
import
List
,
Union
,
Optional
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
json
import
uuid
import
time
import
logging
from
apps.webui.internal.db
import
DB
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
get_db
from
config
import
SRC_LOG_LEVELS
...
...
@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Tag
(
Model
):
id
=
CharField
(
unique
=
True
)
name
=
CharField
()
user_id
=
CharField
()
data
=
TextField
(
null
=
True
)
class
Tag
(
Base
):
__tablename__
=
"tag"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
)
user_id
=
Column
(
String
)
data
=
Column
(
Text
,
nullable
=
True
)
class
ChatIdTag
(
Model
):
id
=
CharField
(
unique
=
True
)
tag_name
=
CharField
()
chat_id
=
CharField
()
user_id
=
CharField
()
timestamp
=
BigIntegerField
()
class
ChatIdTag
(
Base
):
__tablename__
=
"chatidtag"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
tag_name
=
Column
(
String
)
chat_id
=
Column
(
String
)
user_id
=
Column
(
String
)
timestamp
=
Column
(
BigInteger
)
class
TagModel
(
BaseModel
):
...
...
@@ -47,6 +45,8 @@ class TagModel(BaseModel):
user_id
:
str
data
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
class
ChatIdTagModel
(
BaseModel
):
id
:
str
...
...
@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
user_id
:
str
timestamp
:
int
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -75,28 +77,31 @@ class ChatTagsResponse(BaseModel):
class
TagTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
db
.
create_tables
([
Tag
,
ChatIdTag
])
def
insert_new_tag
(
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
id
=
str
(
uuid
.
uuid4
())
tag
=
TagModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"name"
:
name
})
try
:
result
=
Tag
.
create
(
**
tag
.
model_dump
())
if
result
:
return
tag
else
:
with
get_db
()
as
db
:
id
=
str
(
uuid
.
uuid4
())
tag
=
TagModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"name"
:
name
})
try
:
result
=
Tag
(
**
tag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
TagModel
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
return
None
except
Exception
as
e
:
return
None
def
get_tag_by_name_and_user_id
(
self
,
name
:
str
,
user_id
:
str
)
->
Optional
[
TagModel
]:
try
:
tag
=
Tag
.
get
(
Tag
.
name
==
name
,
Tag
.
user_id
==
user_id
)
return
TagModel
(
**
model_to_dict
(
tag
))
with
get_db
()
as
db
:
tag
=
db
.
query
(
Tag
).
filter
(
name
=
name
,
user_id
=
user_id
).
first
()
return
TagModel
.
model_validate
(
tag
)
except
Exception
as
e
:
return
None
...
...
@@ -118,82 +123,110 @@ class TagTable:
}
)
try
:
result
=
ChatIdTag
.
create
(
**
chatIdTag
.
model_dump
())
if
result
:
return
chatIdTag
else
:
return
None
with
get_db
()
as
db
:
result
=
ChatIdTag
(
**
chatIdTag
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
ChatIdTagModel
.
model_validate
(
result
)
else
:
return
None
except
:
return
None
def
get_tags_by_user_id
(
self
,
user_id
:
str
)
->
List
[
TagModel
]:
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
(
ChatIdTag
.
user_id
==
user_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
]
return
[
TagModel
(
**
model_to_dict
(
tag
))
for
tag
in
Tag
.
select
()
.
where
(
Tag
.
user_id
==
user_id
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
]
with
get_db
()
as
db
:
tag_names
=
[
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
return
[
TagModel
.
model_validate
(
tag
)
for
tag
in
(
db
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
def
get_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
)
->
List
[
TagModel
]:
tag_names
=
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
)).
tag_name
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
chat_id
==
chat_id
))
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
]
return
[
TagModel
(
**
model_to_dict
(
tag
))
for
tag
in
Tag
.
select
()
.
where
(
Tag
.
user_id
==
user_id
)
.
where
(
Tag
.
name
.
in_
(
tag_names
))
]
with
get_db
()
as
db
:
tag_names
=
[
chat_id_tag
.
tag_name
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
chat_id
=
chat_id
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
return
[
TagModel
.
model_validate
(
tag
)
for
tag
in
(
db
.
query
(
Tag
)
.
filter_by
(
user_id
=
user_id
)
.
filter
(
Tag
.
name
.
in_
(
tag_names
))
.
all
()
)
]
def
get_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
Optional
[
ChatIdTagModel
]:
return
[
ChatIdTagModel
(
**
model_to_dict
(
chat_id_tag
))
for
chat_id_tag
in
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
tag_name
==
tag_name
))
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
]
)
->
List
[
ChatIdTagModel
]:
with
get_db
()
as
db
:
return
[
ChatIdTagModel
.
model_validate
(
chat_id_tag
)
for
chat_id_tag
in
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
user_id
=
user_id
,
tag_name
=
tag_name
)
.
order_by
(
ChatIdTag
.
timestamp
.
desc
())
.
all
()
)
]
def
count_chat_ids_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
int
:
return
(
ChatIdTag
.
select
()
.
where
((
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
))
.
count
()
)
with
get_db
()
as
db
:
return
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
count
()
)
def
delete_tag_by_tag_name_and_user_id
(
self
,
tag_name
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
with
get_db
()
as
db
:
res
=
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
user_id
=
user_id
)
.
delete
()
)
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
return
True
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
db
.
commit
()
return
True
except
Exception
as
e
:
log
.
error
(
f
"delete_tag:
{
e
}
"
)
return
False
...
...
@@ -202,23 +235,25 @@ class TagTable:
self
,
tag_name
:
str
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
ChatIdTag
.
delete
().
where
(
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
chat_id
==
chat_id
)
&
(
ChatIdTag
.
user_id
==
user_id
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
log
.
debug
(
f
"res:
{
res
}
"
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
# Remove tag item from Tag col as well
query
=
Tag
.
delete
().
where
(
(
Tag
.
name
==
tag_name
)
&
(
Tag
.
user_id
==
user_id
)
with
get_db
()
as
db
:
res
=
(
db
.
query
(
ChatIdTag
)
.
filter_by
(
tag_name
=
tag_name
,
chat_id
=
chat_id
,
user_id
=
user_id
)
.
delete
()
)
log
.
debug
(
f
"res:
{
res
}
"
)
db
.
commit
()
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
if
tag_count
==
0
:
# Remove tag item from Tag col as well
db
.
query
(
Tag
).
filter_by
(
name
=
tag_name
,
user_id
=
user_id
).
delete
()
db
.
commit
()
return
True
return
True
except
Exception
as
e
:
log
.
error
(
f
"delete_tag:
{
e
}
"
)
return
False
...
...
@@ -234,4 +269,4 @@ class TagTable:
return
True
Tags
=
TagTable
(
DB
)
Tags
=
TagTable
()
backend/apps/webui/models/tools.py
View file @
9bcd4ce5
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
List
,
Optional
import
time
import
logging
from
apps.webui.internal.db
import
DB
,
JSONField
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_db
from
apps.webui.models.users
import
Users
import
json
...
...
@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class
Tool
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
name
=
TextField
()
content
=
TextField
()
specs
=
JSONField
()
meta
=
JSONField
()
valves
=
JSONField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Tool
(
Base
):
__tablename__
=
"tool"
class
Meta
:
database
=
DB
id
=
Column
(
String
,
primary_key
=
True
)
user_id
=
Column
(
String
)
name
=
Column
(
Text
)
content
=
Column
(
Text
)
specs
=
Column
(
JSONField
)
meta
=
Column
(
JSONField
)
valves
=
Column
(
JSONField
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
class
ToolMeta
(
BaseModel
):
...
...
@@ -51,6 +50,8 @@ class ToolModel(BaseModel):
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -78,61 +79,68 @@ class ToolValves(BaseModel):
class
ToolsTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Tool
])
def
insert_new_tool
(
self
,
user_id
:
str
,
form_data
:
ToolForm
,
specs
:
List
[
dict
]
)
->
Optional
[
ToolModel
]:
tool
=
ToolModel
(
**
{
**
form_data
.
model_dump
(),
"specs"
:
specs
,
"user_id"
:
user_id
,
"updated_at"
:
int
(
time
.
time
()),
"created_at"
:
int
(
time
.
time
()),
}
)
try
:
result
=
Tool
.
create
(
**
tool
.
model_dump
())
if
result
:
return
tool
else
:
with
get_db
()
as
db
:
tool
=
ToolModel
(
**
{
**
form_data
.
model_dump
(),
"specs"
:
specs
,
"user_id"
:
user_id
,
"updated_at"
:
int
(
time
.
time
()),
"created_at"
:
int
(
time
.
time
()),
}
)
try
:
result
=
Tool
(
**
tool
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
ToolModel
.
model_validate
(
result
)
else
:
return
None
except
Exception
as
e
:
print
(
f
"Error creating tool:
{
e
}
"
)
return
None
except
Exception
as
e
:
print
(
f
"Error creating tool:
{
e
}
"
)
return
None
def
get_tool_by_id
(
self
,
id
:
str
)
->
Optional
[
ToolModel
]:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolModel
(
**
model_to_dict
(
tool
))
with
get_db
()
as
db
:
tool
=
db
.
get
(
Tool
,
id
)
return
ToolModel
.
model_validate
(
tool
)
except
:
return
None
def
get_tools
(
self
)
->
List
[
ToolModel
]:
return
[
ToolModel
(
**
model_to_dict
(
tool
))
for
tool
in
Tool
.
select
()]
with
get_db
()
as
db
:
return
[
ToolModel
.
model_validate
(
tool
)
for
tool
in
db
.
query
(
Tool
).
all
()]
def
get_tool_valves_by_id
(
self
,
id
:
str
)
->
Optional
[
dict
]:
try
:
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
tool
.
valves
if
tool
.
valves
else
{}
with
get_db
()
as
db
:
tool
=
db
.
get
(
Tool
,
id
)
return
tool
.
valves
if
tool
.
valves
else
{}
except
Exception
as
e
:
print
(
f
"An error occurred:
{
e
}
"
)
return
None
def
update_tool_valves_by_id
(
self
,
id
:
str
,
valves
:
dict
)
->
Optional
[
ToolValves
]:
try
:
query
=
Tool
.
update
(
**
{
"valves"
:
valves
},
updated_at
=
int
(
time
.
time
()),
).
where
(
Tool
.
id
==
id
)
query
.
execute
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolValves
(
**
model_to_dict
(
tool
))
with
get_db
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
"valves"
:
valves
,
"updated_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
return
self
.
get_tool_by_id
(
id
)
except
:
return
None
...
...
@@ -141,7 +149,7 @@ class ToolsTable:
)
->
Optional
[
dict
]:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "tools" and "valves" settings
if
"tools"
not
in
user_settings
:
...
...
@@ -159,7 +167,7 @@ class ToolsTable:
)
->
Optional
[
dict
]:
try
:
user
=
Users
.
get_user_by_id
(
user_id
)
user_settings
=
user
.
settings
.
model_dump
()
user_settings
=
user
.
settings
.
model_dump
()
if
user
.
settings
else
{}
# Check if user has "tools" and "valves" settings
if
"tools"
not
in
user_settings
:
...
...
@@ -170,8 +178,7 @@ class ToolsTable:
user_settings
[
"tools"
][
"valves"
][
id
]
=
valves
# Update the user settings in the database
query
=
Users
.
update_user_by_id
(
user_id
,
{
"settings"
:
user_settings
})
query
.
execute
()
Users
.
update_user_by_id
(
user_id
,
{
"settings"
:
user_settings
})
return
user_settings
[
"tools"
][
"valves"
][
id
]
except
Exception
as
e
:
...
...
@@ -180,25 +187,27 @@ class ToolsTable:
def
update_tool_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
ToolModel
]:
try
:
query
=
Tool
.
update
(
**
updated
,
updated_at
=
int
(
time
.
time
()),
).
where
(
Tool
.
id
==
id
)
query
.
execute
()
tool
=
Tool
.
get
(
Tool
.
id
==
id
)
return
ToolModel
(
**
model_to_dict
(
tool
))
with
get_db
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
update
(
{
**
updated
,
"updated_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
tool
=
db
.
query
(
Tool
).
get
(
id
)
db
.
refresh
(
tool
)
return
ToolModel
.
model_validate
(
tool
)
except
:
return
None
def
delete_tool_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Tool
.
delete
().
where
((
Tool
.
id
==
id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
db
.
query
(
Tool
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
return
True
except
:
return
False
Tools
=
ToolsTable
(
DB
)
Tools
=
ToolsTable
()
backend/apps/webui/models/users.py
View file @
9bcd4ce5
from
pydantic
import
BaseModel
,
ConfigDict
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
,
parse_obj_as
from
typing
import
List
,
Union
,
Optional
import
time
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
utils.misc
import
get_gravatar_url
from
apps.webui.internal.db
import
D
B
,
JSONField
from
apps.webui.internal.db
import
B
ase
,
JSONField
,
Session
,
get_db
from
apps.webui.models.chats
import
Chats
####################
...
...
@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
####################
class
User
(
Model
):
id
=
CharField
(
unique
=
True
)
name
=
CharField
()
email
=
CharField
()
role
=
CharField
()
profile_image_url
=
TextField
()
class
User
(
Base
):
__tablename__
=
"user"
last_active_at
=
BigIntegerField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
id
=
Column
(
String
,
primary_key
=
True
)
name
=
Column
(
String
)
email
=
Column
(
String
)
role
=
Column
(
String
)
profile_image_url
=
Column
(
Text
)
api_key
=
CharField
(
null
=
True
,
unique
=
True
)
settings
=
JSONField
(
null
=
True
)
info
=
JSONField
(
null
=
True
)
last_active_at
=
Column
(
BigInteger
)
updated_at
=
Column
(
BigInteger
)
created_at
=
Column
(
BigInteger
)
oauth_sub
=
TextField
(
null
=
True
,
unique
=
True
)
api_key
=
Column
(
String
,
nullable
=
True
,
unique
=
True
)
settings
=
Column
(
JSONField
,
nullable
=
True
)
info
=
Column
(
JSONField
,
nullable
=
True
)
class
Meta
:
database
=
DB
oauth_sub
=
Column
(
Text
,
unique
=
True
)
class
UserSettings
(
BaseModel
):
...
...
@@ -57,6 +57,8 @@ class UserModel(BaseModel):
oauth_sub
:
Optional
[
str
]
=
None
model_config
=
ConfigDict
(
from_attributes
=
True
)
####################
# Forms
...
...
@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
class
UsersTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
User
])
def
insert_new_user
(
self
,
...
...
@@ -89,77 +88,92 @@ class UsersTable:
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
user
=
UserModel
(
**
{
"id"
:
id
,
"name"
:
name
,
"email"
:
email
,
"role"
:
role
,
"profile_image_url"
:
profile_image_url
,
"last_active_at"
:
int
(
time
.
time
()),
"created_at"
:
int
(
time
.
time
()),
"updated_at"
:
int
(
time
.
time
()),
"oauth_sub"
:
oauth_sub
,
}
)
result
=
User
.
create
(
**
user
.
model_dump
())
if
result
:
return
user
else
:
return
None
with
get_db
()
as
db
:
user
=
UserModel
(
**
{
"id"
:
id
,
"name"
:
name
,
"email"
:
email
,
"role"
:
role
,
"profile_image_url"
:
profile_image_url
,
"last_active_at"
:
int
(
time
.
time
()),
"created_at"
:
int
(
time
.
time
()),
"updated_at"
:
int
(
time
.
time
()),
"oauth_sub"
:
oauth_sub
,
}
)
result
=
User
(
**
user
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
db
.
refresh
(
result
)
if
result
:
return
user
else
:
return
None
def
get_user_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
Exception
as
e
:
return
None
def
get_user_by_api_key
(
self
,
api_key
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
api_key
==
api_key
)
return
UserModel
(
**
model_to_dict
(
user
))
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
api_key
=
api_key
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
email
==
email
)
return
UserModel
(
**
model_to_dict
(
user
))
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
email
=
email
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_user_by_oauth_sub
(
self
,
sub
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
oauth_sub
==
sub
)
return
UserModel
(
**
model_to_dict
(
user
))
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
oauth_sub
=
sub
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
get_users
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
return
[
UserModel
(
**
model_to_dict
(
user
))
for
user
in
User
.
select
()
# .limit(limit).offset(skip)
]
with
get_db
()
as
db
:
users
=
(
db
.
query
(
User
)
# .offset(skip).limit(limit)
.
all
()
)
return
[
UserModel
.
model_validate
(
user
)
for
user
in
users
]
def
get_num_users
(
self
)
->
Optional
[
int
]:
return
User
.
select
().
count
()
with
get_db
()
as
db
:
return
db
.
query
(
User
).
count
()
def
get_first_user
(
self
)
->
UserModel
:
try
:
user
=
User
.
select
().
order_by
(
User
.
created_at
).
first
()
return
UserModel
(
**
model_to_dict
(
user
))
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
order_by
(
User
.
created_at
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
update_user_role_by_id
(
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
role
=
role
).
where
(
User
.
id
==
id
)
query
.
execute
(
)
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
with
get_db
()
as
db
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"role"
:
role
}
)
db
.
commit
()
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
...
...
@@ -167,23 +181,28 @@ class UsersTable:
self
,
id
:
str
,
profile_image_url
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
profile_image_url
=
profile_image_url
).
where
(
User
.
id
==
id
)
query
.
execute
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
with
get_db
()
as
db
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"profile_image_url"
:
profile_image_url
}
)
db
.
commit
()
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
def
update_user_last_active_by_id
(
self
,
id
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
last_active_at
=
int
(
time
.
time
())).
where
(
User
.
id
==
id
)
query
.
execute
()
with
get_db
()
as
db
:
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
{
"last_active_at"
:
int
(
time
.
time
())}
)
db
.
commit
()
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
except
:
return
None
...
...
@@ -191,22 +210,25 @@ class UsersTable:
self
,
id
:
str
,
oauth_sub
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
oauth_sub
=
oauth_sub
).
where
(
User
.
id
==
id
)
query
.
execute
()
with
get_db
()
as
db
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"oauth_sub"
:
oauth_sub
})
db
.
commit
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_
to_dict
(
user
)
)
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
(
)
return
UserModel
.
model_
validate
(
user
)
except
:
return
None
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
**
updated
).
where
(
User
.
id
==
id
)
query
.
execute
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
with
get_db
()
as
db
:
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
(
updated
)
db
.
commit
()
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
UserModel
.
model_validate
(
user
)
# return UserModel(**user.dict())
except
Exception
as
e
:
return
None
def
delete_user_by_id
(
self
,
id
:
str
)
->
bool
:
...
...
@@ -215,9 +237,10 @@ class UsersTable:
result
=
Chats
.
delete_chats_by_user_id
(
id
)
if
result
:
# Delete User
query
=
User
.
delete
().
where
(
User
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
with
get_db
()
as
db
:
# Delete User
db
.
query
(
User
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
return
True
else
:
...
...
@@ -227,19 +250,20 @@ class UsersTable:
def
update_user_api_key_by_id
(
self
,
id
:
str
,
api_key
:
str
)
->
str
:
try
:
query
=
User
.
update
(
api_key
=
api_key
).
where
(
User
.
id
==
id
)
result
=
query
.
execute
(
)
return
True
if
result
==
1
else
False
with
get_db
()
as
db
:
result
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
update
({
"api_key"
:
api_key
}
)
db
.
commit
()
return
True
if
result
==
1
else
False
except
:
return
False
def
get_user_api_key_by_id
(
self
,
id
:
str
)
->
Optional
[
str
]:
try
:
user
=
User
.
get
(
User
.
id
==
id
)
return
user
.
api_key
except
:
with
get_db
()
as
db
:
user
=
db
.
query
(
User
).
filter_by
(
id
=
id
).
first
()
return
user
.
api_key
except
Exception
as
e
:
return
None
Users
=
UsersTable
(
DB
)
Users
=
UsersTable
()
backend/apps/webui/routers/chats.py
View file @
9bcd4ce5
...
...
@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@
router
.
get
(
"/list/user/{user_id}"
,
response_model
=
List
[
ChatTitleIdResponse
])
async
def
get_user_chat_list_by_user_id
(
user_id
:
str
,
user
=
Depends
(
get_admin_user
),
skip
:
int
=
0
,
limit
:
int
=
50
user_id
:
str
,
user
=
Depends
(
get_admin_user
),
skip
:
int
=
0
,
limit
:
int
=
50
,
):
return
Chats
.
get_chat_list_by_user_id
(
user_id
,
include_archived
=
True
,
skip
=
skip
,
limit
=
limit
...
...
@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
@
router
.
get
(
"/all/archived"
,
response_model
=
List
[
ChatResponse
])
async
def
get_user_chats
(
user
=
Depends
(
get_verified_user
)):
async
def
get_user_
archived_
chats
(
user
=
Depends
(
get_verified_user
)):
return
[
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
for
chat
in
Chats
.
get_archived_chats_by_user_id
(
user
.
id
)
...
...
@@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name(
form_data
:
TagNameForm
,
user
=
Depends
(
get_verified_user
)
):
print
(
form_data
)
chat_ids
=
[
chat_id_tag
.
chat_id
for
chat_id_tag
in
Tags
.
get_chat_ids_by_tag_name_and_user_id
(
...
...
backend/apps/webui/routers/documents.py
View file @
9bcd4ce5
...
...
@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
@
router
.
post
(
"/doc/update"
,
response_model
=
Optional
[
DocumentResponse
])
async
def
update_doc_by_name
(
name
:
str
,
form_data
:
DocumentUpdateForm
,
user
=
Depends
(
get_admin_user
)
name
:
str
,
form_data
:
DocumentUpdateForm
,
user
=
Depends
(
get_admin_user
),
):
doc
=
Documents
.
update_doc_by_name
(
name
,
form_data
)
if
doc
:
...
...
backend/apps/webui/routers/files.py
View file @
9bcd4ce5
...
...
@@ -50,10 +50,7 @@ router = APIRouter()
@
router
.
post
(
"/"
)
def
upload_file
(
file
:
UploadFile
=
File
(...),
user
=
Depends
(
get_verified_user
),
):
def
upload_file
(
file
:
UploadFile
=
File
(...),
user
=
Depends
(
get_verified_user
)):
log
.
info
(
f
"file.content_type:
{
file
.
content_type
}
"
)
try
:
unsanitized_filename
=
file
.
filename
...
...
backend/apps/webui/routers/functions.py
View file @
9bcd4ce5
...
...
@@ -233,7 +233,10 @@ async def delete_function_by_id(
# delete the function file
function_path
=
os
.
path
.
join
(
FUNCTIONS_DIR
,
f
"
{
id
}
.py"
)
os
.
remove
(
function_path
)
try
:
os
.
remove
(
function_path
)
except
:
pass
return
result
...
...
backend/apps/webui/routers/memories.py
View file @
9bcd4ce5
...
...
@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel):
@
router
.
post
(
"/add"
,
response_model
=
Optional
[
MemoryModel
])
async
def
add_memory
(
request
:
Request
,
form_data
:
AddMemoryForm
,
user
=
Depends
(
get_verified_user
)
request
:
Request
,
form_data
:
AddMemoryForm
,
user
=
Depends
(
get_verified_user
),
):
memory
=
Memories
.
insert_new_memory
(
user
.
id
,
form_data
.
content
)
memory_embedding
=
request
.
app
.
state
.
EMBEDDING_FUNCTION
(
memory
.
content
)
...
...
backend/apps/webui/routers/models.py
View file @
9bcd4ce5
...
...
@@ -5,6 +5,7 @@ from typing import List, Union, Optional
from
fastapi
import
APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.webui.models.models
import
Models
,
ModelModel
,
ModelForm
,
ModelResponse
from
utils.utils
import
get_verified_user
,
get_admin_user
...
...
@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)):
@
router
.
post
(
"/add"
,
response_model
=
Optional
[
ModelModel
])
async
def
add_new_model
(
request
:
Request
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
)
request
:
Request
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
),
):
if
form_data
.
id
in
request
.
app
.
state
.
MODELS
:
raise
HTTPException
(
...
...
@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@
router
.
post
(
"/update"
,
response_model
=
Optional
[
ModelModel
])
async
def
update_model_by_id
(
request
:
Request
,
id
:
str
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
)
request
:
Request
,
id
:
str
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
),
):
model
=
Models
.
get_model_by_id
(
id
)
if
model
:
...
...
backend/apps/webui/routers/prompts.py
View file @
9bcd4ce5
...
...
@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
@
router
.
post
(
"/command/{command}/update"
,
response_model
=
Optional
[
PromptModel
])
async
def
update_prompt_by_command
(
command
:
str
,
form_data
:
PromptForm
,
user
=
Depends
(
get_admin_user
)
command
:
str
,
form_data
:
PromptForm
,
user
=
Depends
(
get_admin_user
),
):
prompt
=
Prompts
.
update_prompt_by_command
(
f
"/
{
command
}
"
,
form_data
)
if
prompt
:
...
...
backend/apps/webui/routers/tools.py
View file @
9bcd4ce5
...
...
@@ -6,7 +6,6 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.webui.models.users
import
Users
from
apps.webui.models.tools
import
Tools
,
ToolForm
,
ToolModel
,
ToolResponse
from
apps.webui.utils
import
load_toolkit_module_by_id
...
...
@@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)):
@
router
.
post
(
"/create"
,
response_model
=
Optional
[
ToolResponse
])
async
def
create_new_toolkit
(
request
:
Request
,
form_data
:
ToolForm
,
user
=
Depends
(
get_admin_user
)
request
:
Request
,
form_data
:
ToolForm
,
user
=
Depends
(
get_admin_user
),
):
if
not
form_data
.
id
.
isidentifier
():
raise
HTTPException
(
...
...
@@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
@
router
.
post
(
"/id/{id}/update"
,
response_model
=
Optional
[
ToolModel
])
async
def
update_toolkit_by_id
(
request
:
Request
,
id
:
str
,
form_data
:
ToolForm
,
user
=
Depends
(
get_admin_user
)
request
:
Request
,
id
:
str
,
form_data
:
ToolForm
,
user
=
Depends
(
get_admin_user
),
):
toolkit_path
=
os
.
path
.
join
(
TOOLS_DIR
,
f
"
{
id
}
.py"
)
...
...
backend/apps/webui/routers/users.py
View file @
9bcd4ce5
...
...
@@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@
router
.
post
(
"/user/info/update"
,
response_model
=
Optional
[
dict
])
async
def
update_user_
settings
_by_session_user
(
async
def
update_user_
info
_by_session_user
(
form_data
:
dict
,
user
=
Depends
(
get_verified_user
)
):
user
=
Users
.
get_user_by_id
(
user
.
id
)
...
...
@@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@
router
.
post
(
"/{user_id}/update"
,
response_model
=
Optional
[
UserModel
])
async
def
update_user_by_id
(
user_id
:
str
,
form_data
:
UserUpdateForm
,
session_user
=
Depends
(
get_admin_user
)
user_id
:
str
,
form_data
:
UserUpdateForm
,
session_user
=
Depends
(
get_admin_user
),
):
user
=
Users
.
get_user_by_id
(
user_id
)
...
...
backend/apps/webui/routers/utils.py
View file @
9bcd4ce5
from
fastapi
import
APIRouter
,
UploadFile
,
File
,
Response
from
fastapi
import
Depends
,
HTTPException
,
status
from
peewee
import
SqliteDatabase
from
starlette.responses
import
StreamingResponse
,
FileResponse
from
pydantic
import
BaseModel
...
...
@@ -10,7 +9,6 @@ import markdown
import
black
from
apps.webui.internal.db
import
DB
from
utils.utils
import
get_admin_user
from
utils.misc
import
calculate_sha256
,
get_gravatar_url
...
...
@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
,
)
if
not
isinstance
(
DB
,
SqliteDatabase
):
from
apps.webui.internal.db
import
engine
if
engine
.
name
!=
"sqlite"
:
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DB_NOT_SQLITE
,
)
return
FileResponse
(
DB
.
database
,
engine
.
url
.
database
,
media_type
=
"application/octet-stream"
,
filename
=
"webui.db"
,
)
...
...
backend/config.py
View file @
9bcd4ce5
...
...
@@ -5,9 +5,8 @@ import importlib.metadata
import
pkgutil
import
chromadb
from
chromadb
import
Settings
from
base64
import
b64encode
from
bs4
import
BeautifulSoup
from
typing
import
TypeVar
,
Generic
,
Union
from
typing
import
TypeVar
,
Generic
from
pydantic
import
BaseModel
from
typing
import
Optional
...
...
@@ -19,7 +18,6 @@ import markdown
import
requests
import
shutil
from
secrets
import
token_bytes
from
constants
import
ERROR_MESSAGES
####################################
...
...
@@ -395,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
os
.
environ
.
get
(
"OAUTH_PROVIDER_NAME"
,
"SSO"
),
)
OAUTH_USERNAME_CLAIM
=
PersistentConfig
(
"OAUTH_USERNAME_CLAIM"
,
"oauth.oidc.username_claim"
,
os
.
environ
.
get
(
"OAUTH_USERNAME_CLAIM"
,
"name"
),
)
OAUTH_PICTURE_CLAIM
=
PersistentConfig
(
"OAUTH_USERNAME_CLAIM"
,
"oauth.oidc.avatar_claim"
,
os
.
environ
.
get
(
"OAUTH_PICTURE_CLAIM"
,
"picture"
),
)
def
load_oauth_providers
():
OAUTH_PROVIDERS
.
clear
()
...
...
@@ -440,16 +450,27 @@ load_oauth_providers()
STATIC_DIR
=
Path
(
os
.
getenv
(
"STATIC_DIR"
,
BACKEND_DIR
/
"static"
)).
resolve
()
frontend_favicon
=
FRONTEND_BUILD_DIR
/
"favicon.png"
frontend_favicon
=
FRONTEND_BUILD_DIR
/
"static"
/
"favicon.png"
if
frontend_favicon
.
exists
():
try
:
shutil
.
copyfile
(
frontend_favicon
,
STATIC_DIR
/
"favicon.png"
)
except
Exception
as
e
:
logging
.
error
(
f
"An error occurred:
{
e
}
"
)
else
:
logging
.
warning
(
f
"Frontend favicon not found at
{
frontend_favicon
}
"
)
frontend_splash
=
FRONTEND_BUILD_DIR
/
"static"
/
"splash.png"
if
frontend_splash
.
exists
():
try
:
shutil
.
copyfile
(
frontend_splash
,
STATIC_DIR
/
"splash.png"
)
except
Exception
as
e
:
logging
.
error
(
f
"An error occurred:
{
e
}
"
)
else
:
logging
.
warning
(
f
"Frontend splash not found at
{
frontend_splash
}
"
)
####################################
# CUSTOM_NAME
####################################
...
...
@@ -474,6 +495,19 @@ if CUSTOM_NAME:
r
.
raw
.
decode_content
=
True
shutil
.
copyfileobj
(
r
.
raw
,
f
)
if
"splash"
in
data
:
url
=
(
f
"https://api.openwebui.com
{
data
[
'splash'
]
}
"
if
data
[
"splash"
][
0
]
==
"/"
else
data
[
"splash"
]
)
r
=
requests
.
get
(
url
,
stream
=
True
)
if
r
.
status_code
==
200
:
with
open
(
f
"
{
STATIC_DIR
}
/splash.png"
,
"wb"
)
as
f
:
r
.
raw
.
decode_content
=
True
shutil
.
copyfileobj
(
r
.
raw
,
f
)
WEBUI_NAME
=
data
[
"name"
]
except
Exception
as
e
:
log
.
exception
(
e
)
...
...
@@ -769,11 +803,14 @@ class BannerModel(BaseModel):
timestamp
:
int
WEBUI_BANNERS
=
PersistentConfig
(
"WEBUI_BANNERS"
,
"ui.banners"
,
[
BannerModel
(
**
banner
)
for
banner
in
json
.
loads
(
"[]"
)],
)
try
:
banners
=
json
.
loads
(
os
.
environ
.
get
(
"WEBUI_BANNERS"
,
"[]"
))
banners
=
[
BannerModel
(
**
banner
)
for
banner
in
banners
]
except
Exception
as
e
:
print
(
f
"Error loading WEBUI_BANNERS:
{
e
}
"
)
banners
=
[]
WEBUI_BANNERS
=
PersistentConfig
(
"WEBUI_BANNERS"
,
"ui.banners"
,
banners
)
SHOW_ADMIN_DETAILS
=
PersistentConfig
(
...
...
@@ -885,6 +922,22 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
if
WEBUI_AUTH
and
WEBUI_SECRET_KEY
==
""
:
raise
ValueError
(
ERROR_MESSAGES
.
ENV_VAR_NOT_FOUND
)
####################################
# RAG document content extraction
####################################
CONTENT_EXTRACTION_ENGINE
=
PersistentConfig
(
"CONTENT_EXTRACTION_ENGINE"
,
"rag.CONTENT_EXTRACTION_ENGINE"
,
os
.
environ
.
get
(
"CONTENT_EXTRACTION_ENGINE"
,
""
).
lower
(),
)
TIKA_SERVER_URL
=
PersistentConfig
(
"TIKA_SERVER_URL"
,
"rag.tika_server_url"
,
os
.
getenv
(
"TIKA_SERVER_URL"
,
"http://tika:9998"
),
# Default for sidecar deployment
)
####################################
# RAG
####################################
...
...
@@ -1302,3 +1355,7 @@ AUDIO_TTS_VOICE = PersistentConfig(
####################################
DATABASE_URL
=
os
.
environ
.
get
(
"DATABASE_URL"
,
f
"sqlite:///
{
DATA_DIR
}
/webui.db"
)
# Replace the postgres:// with postgresql://
if
"postgres://"
in
DATABASE_URL
:
DATABASE_URL
=
DATABASE_URL
.
replace
(
"postgres://"
,
"postgresql://"
)
backend/constants.py
View file @
9bcd4ce5
...
...
@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
OLLAMA_API_DISABLED
=
(
"The Ollama API is disabled. Please enable it to use this feature."
)
class
TASKS
(
str
,
Enum
):
def
__str__
(
self
)
->
str
:
return
super
().
__str__
()
DEFAULT
=
lambda
task
=
""
:
f
"
{
task
if
task
else
'default'
}
"
TITLE_GENERATION
=
"Title Generation"
EMOJI_GENERATION
=
"Emoji Generation"
QUERY_GENERATION
=
"Query Generation"
FUNCTION_CALLING
=
"Function Calling"
backend/main.py
View file @
9bcd4ce5
...
...
@@ -4,9 +4,7 @@ from contextlib import asynccontextmanager
from
authlib.integrations.starlette_client
import
OAuth
from
authlib.oidc.core
import
UserInfo
from
bs4
import
BeautifulSoup
import
json
import
markdown
import
time
import
os
import
sys
...
...
@@ -18,25 +16,22 @@ import shutil
import
os
import
uuid
import
inspect
import
asyncio
from
fastapi.concurrency
import
run_in_threadpool
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
,
UploadFile
,
File
,
Form
from
fastapi.staticfiles
import
StaticFiles
from
fastapi.responses
import
JSONResponse
from
fastapi
import
HTTPException
from
fastapi.middleware.wsgi
import
WSGIMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
sqlalchemy
import
text
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
,
RedirectResponse
from
apps.socket.main
import
app
as
socket_app
from
apps.socket.main
import
sio
,
app
as
socket_app
from
apps.ollama.main
import
(
app
as
ollama_app
,
OpenAIChatCompletionForm
,
get_all_models
as
get_ollama_models
,
generate_openai_chat_completion
as
generate_ollama_chat_completion
,
)
...
...
@@ -54,13 +49,14 @@ from apps.webui.main import (
get_pipe_models
,
generate_function_chat_completion
,
)
from
apps.webui.internal.db
import
Session
from
pydantic
import
BaseModel
from
typing
import
List
,
Optional
,
Iterator
,
Generator
,
Union
from
typing
import
List
,
Optional
from
apps.webui.models.auths
import
Auths
from
apps.webui.models.models
import
Models
,
ModelModel
from
apps.webui.models.models
import
Models
from
apps.webui.models.tools
import
Tools
from
apps.webui.models.functions
import
Functions
from
apps.webui.models.users
import
Users
...
...
@@ -83,14 +79,12 @@ from utils.task import (
from
utils.misc
import
(
get_last_user_message
,
add_or_update_system_message
,
stream_message_template
,
parse_duration
,
)
from
apps.rag.utils
import
get_rag_context
,
rag_template
from
config
import
(
CONFIG_DATA
,
WEBUI_NAME
,
WEBUI_URL
,
WEBUI_AUTH
,
...
...
@@ -98,7 +92,6 @@ from config import (
VERSION
,
CHANGELOG
,
FRONTEND_BUILD_DIR
,
UPLOAD_DIR
,
CACHE_DIR
,
STATIC_DIR
,
DEFAULT_LOCALE
,
...
...
@@ -126,7 +119,8 @@ from config import (
WEBUI_SESSION_COOKIE_SECURE
,
AppConfig
,
)
from
constants
import
ERROR_MESSAGES
,
WEBHOOK_MESSAGES
from
constants
import
ERROR_MESSAGES
,
WEBHOOK_MESSAGES
,
TASKS
from
utils.webhook
import
post_webhook
if
SAFE_MODE
:
...
...
@@ -167,8 +161,20 @@ https://github.com/open-webui/open-webui
)
def
run_migrations
():
try
:
from
alembic.config
import
Config
from
alembic
import
command
alembic_cfg
=
Config
(
"alembic.ini"
)
command
.
upgrade
(
alembic_cfg
,
"head"
)
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
@
asynccontextmanager
async
def
lifespan
(
app
:
FastAPI
):
run_migrations
()
yield
...
...
@@ -212,8 +218,79 @@ origins = ["*"]
##################################
async
def
get_body_and_model_and_user
(
request
):
# Read the original request body
body
=
await
request
.
body
()
body_str
=
body
.
decode
(
"utf-8"
)
body
=
json
.
loads
(
body_str
)
if
body_str
else
{}
model_id
=
body
[
"model"
]
if
model_id
not
in
app
.
state
.
MODELS
:
raise
Exception
(
"Model not found"
)
model
=
app
.
state
.
MODELS
[
model_id
]
user
=
get_current_user
(
request
,
get_http_authorization_cred
(
request
.
headers
.
get
(
"Authorization"
)),
)
return
body
,
model
,
user
def
get_task_model_id
(
default_model_id
):
# Set the task model
task_model_id
=
default_model_id
# Check if the user has a custom task model and use that model
if
app
.
state
.
MODELS
[
task_model_id
][
"owned_by"
]
==
"ollama"
:
if
(
app
.
state
.
config
.
TASK_MODEL
and
app
.
state
.
config
.
TASK_MODEL
in
app
.
state
.
MODELS
):
task_model_id
=
app
.
state
.
config
.
TASK_MODEL
else
:
if
(
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
and
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
in
app
.
state
.
MODELS
):
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
return
task_model_id
def
get_filter_function_ids
(
model
):
def
get_priority
(
function_id
):
function
=
Functions
.
get_function_by_id
(
function_id
)
if
function
is
not
None
and
hasattr
(
function
,
"valves"
):
return
(
function
.
valves
if
function
.
valves
else
{}).
get
(
"priority"
,
0
)
return
0
filter_ids
=
[
function
.
id
for
function
in
Functions
.
get_global_filter_functions
()]
if
"info"
in
model
and
"meta"
in
model
[
"info"
]:
filter_ids
.
extend
(
model
[
"info"
][
"meta"
].
get
(
"filterIds"
,
[]))
filter_ids
=
list
(
set
(
filter_ids
))
enabled_filter_ids
=
[
function
.
id
for
function
in
Functions
.
get_functions_by_type
(
"filter"
,
active_only
=
True
)
]
filter_ids
=
[
filter_id
for
filter_id
in
filter_ids
if
filter_id
in
enabled_filter_ids
]
filter_ids
.
sort
(
key
=
get_priority
)
return
filter_ids
async
def
get_function_call_response
(
messages
,
files
,
tool_id
,
template
,
task_model_id
,
user
messages
,
files
,
tool_id
,
template
,
task_model_id
,
user
,
__event_emitter__
=
None
,
__event_call__
=
None
,
):
tool
=
Tools
.
get_tool_by_id
(
tool_id
)
tools_specs
=
json
.
dumps
(
tool
.
specs
,
indent
=
2
)
...
...
@@ -240,6 +317,7 @@ async def get_function_call_response(
{
"role"
:
"user"
,
"content"
:
f
"Query:
{
prompt
}
"
},
],
"stream"
:
False
,
"task"
:
TASKS
.
FUNCTION_CALLING
,
}
try
:
...
...
@@ -252,7 +330,6 @@ async def get_function_call_response(
response
=
None
try
:
response
=
await
generate_chat_completions
(
form_data
=
payload
,
user
=
user
)
content
=
None
if
hasattr
(
response
,
"body_iterator"
):
...
...
@@ -266,334 +343,367 @@ async def get_function_call_response(
else
:
content
=
response
[
"choices"
][
0
][
"message"
][
"content"
]
if
content
is
None
:
return
None
,
None
,
False
# Parse the function response
if
content
is
not
None
:
print
(
f
"content:
{
content
}
"
)
result
=
json
.
loads
(
content
)
print
(
result
)
citation
=
None
# Call the function
if
"name"
in
result
:
if
tool_id
in
webui_app
.
state
.
TOOLS
:
toolkit_module
=
webui_app
.
state
.
TOOLS
[
tool_id
]
else
:
toolkit_module
,
frontmatter
=
load_toolkit_module_by_id
(
tool_id
)
webui_app
.
state
.
TOOLS
[
tool_id
]
=
toolkit_module
file_handler
=
False
# check if toolkit_module has file_handler self variable
if
hasattr
(
toolkit_module
,
"file_handler"
):
file_handler
=
True
print
(
"file_handler: "
,
file_handler
)
if
hasattr
(
toolkit_module
,
"valves"
)
and
hasattr
(
toolkit_module
,
"Valves"
):
valves
=
Tools
.
get_tool_valves_by_id
(
tool_id
)
toolkit_module
.
valves
=
toolkit_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
print
(
f
"content:
{
content
}
"
)
result
=
json
.
loads
(
content
)
print
(
result
)
function
=
getattr
(
toolkit_module
,
result
[
"name"
])
function_result
=
None
try
:
# Get the signature of the function
sig
=
inspect
.
signature
(
function
)
params
=
result
[
"parameters"
]
citation
=
None
if
"__user__"
in
sig
.
parameters
:
# Call the function with the '__user__' parameter included
__user__
=
{
"id"
:
user
.
id
,
"email"
:
user
.
email
,
"name"
:
user
.
name
,
"role"
:
user
.
role
,
}
try
:
if
hasattr
(
toolkit_module
,
"UserValves"
):
__user__
[
"valves"
]
=
toolkit_module
.
UserValves
(
**
Tools
.
get_user_valves_by_id_and_user_id
(
tool_id
,
user
.
id
)
)
except
Exception
as
e
:
print
(
e
)
params
=
{
**
params
,
"__user__"
:
__user__
}
if
"__messages__"
in
sig
.
parameters
:
# Call the function with the '__messages__' parameter included
params
=
{
**
params
,
"__messages__"
:
messages
,
}
if
"__files__"
in
sig
.
parameters
:
# Call the function with the '__files__' parameter included
params
=
{
**
params
,
"__files__"
:
files
,
}
if
"__model__"
in
sig
.
parameters
:
# Call the function with the '__model__' parameter included
params
=
{
**
params
,
"__model__"
:
model
,
}
if
"__id__"
in
sig
.
parameters
:
# Call the function with the '__id__' parameter included
params
=
{
**
params
,
"__id__"
:
tool_id
,
}
if
inspect
.
iscoroutinefunction
(
function
):
function_result
=
await
function
(
**
params
)
else
:
function_result
=
function
(
**
params
)
if
hasattr
(
toolkit_module
,
"citation"
)
and
toolkit_module
.
citation
:
citation
=
{
"source"
:
{
"name"
:
f
"TOOL:
{
tool
.
name
}
/
{
result
[
'name'
]
}
"
},
"document"
:
[
function_result
],
"metadata"
:
[{
"source"
:
result
[
"name"
]}],
}
if
"name"
not
in
result
:
return
None
,
None
,
False
# Call the function
if
tool_id
in
webui_app
.
state
.
TOOLS
:
toolkit_module
=
webui_app
.
state
.
TOOLS
[
tool_id
]
else
:
toolkit_module
,
_
=
load_toolkit_module_by_id
(
tool_id
)
webui_app
.
state
.
TOOLS
[
tool_id
]
=
toolkit_module
file_handler
=
False
# check if toolkit_module has file_handler self variable
if
hasattr
(
toolkit_module
,
"file_handler"
):
file_handler
=
True
print
(
"file_handler: "
,
file_handler
)
if
hasattr
(
toolkit_module
,
"valves"
)
and
hasattr
(
toolkit_module
,
"Valves"
):
valves
=
Tools
.
get_tool_valves_by_id
(
tool_id
)
toolkit_module
.
valves
=
toolkit_module
.
Valves
(
**
(
valves
if
valves
else
{}))
function
=
getattr
(
toolkit_module
,
result
[
"name"
])
function_result
=
None
try
:
# Get the signature of the function
sig
=
inspect
.
signature
(
function
)
params
=
result
[
"parameters"
]
# Extra parameters to be passed to the function
extra_params
=
{
"__model__"
:
model
,
"__id__"
:
tool_id
,
"__messages__"
:
messages
,
"__files__"
:
files
,
"__event_emitter__"
:
__event_emitter__
,
"__event_call__"
:
__event_call__
,
}
# Add extra params in contained in function signature
for
key
,
value
in
extra_params
.
items
():
if
key
in
sig
.
parameters
:
params
[
key
]
=
value
if
"__user__"
in
sig
.
parameters
:
# Call the function with the '__user__' parameter included
__user__
=
{
"id"
:
user
.
id
,
"email"
:
user
.
email
,
"name"
:
user
.
name
,
"role"
:
user
.
role
,
}
try
:
if
hasattr
(
toolkit_module
,
"UserValves"
):
__user__
[
"valves"
]
=
toolkit_module
.
UserValves
(
**
Tools
.
get_user_valves_by_id_and_user_id
(
tool_id
,
user
.
id
)
)
except
Exception
as
e
:
print
(
e
)
# Add the function result to the system prompt
if
function_result
is
not
None
:
return
function_result
,
citation
,
file_handler
params
=
{
**
params
,
"__user__"
:
__user__
}
if
inspect
.
iscoroutinefunction
(
function
):
function_result
=
await
function
(
**
params
)
else
:
function_result
=
function
(
**
params
)
if
hasattr
(
toolkit_module
,
"citation"
)
and
toolkit_module
.
citation
:
citation
=
{
"source"
:
{
"name"
:
f
"TOOL:
{
tool
.
name
}
/
{
result
[
'name'
]
}
"
},
"document"
:
[
function_result
],
"metadata"
:
[{
"source"
:
result
[
"name"
]}],
}
except
Exception
as
e
:
print
(
e
)
# Add the function result to the system prompt
if
function_result
is
not
None
:
return
function_result
,
citation
,
file_handler
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
return
None
,
None
,
False
class
ChatCompletionMiddleware
(
BaseHTTPMiddleware
):
async
def
dispatch
(
self
,
request
:
Request
,
call_next
):
data_items
=
[]
async
def
chat_completion_functions_handler
(
body
,
model
,
user
,
__event_emitter__
,
__event_call__
):
skip_files
=
None
filter_ids
=
get_filter_function_ids
(
model
)
for
filter_id
in
filter_ids
:
filter
=
Functions
.
get_function_by_id
(
filter_id
)
if
not
filter
:
continue
if
filter_id
in
webui_app
.
state
.
FUNCTIONS
:
function_module
=
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
else
:
function_module
,
_
,
_
=
load_function_module_by_id
(
filter_id
)
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
=
function_module
# Check if the function has a file_handler variable
if
hasattr
(
function_module
,
"file_handler"
):
skip_files
=
function_module
.
file_handler
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
filter_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
if
not
hasattr
(
function_module
,
"inlet"
):
continue
try
:
inlet
=
function_module
.
inlet
# Get the signature of the function
sig
=
inspect
.
signature
(
inlet
)
params
=
{
"body"
:
body
}
# Extra parameters to be passed to the function
extra_params
=
{
"__model__"
:
model
,
"__id__"
:
filter_id
,
"__event_emitter__"
:
__event_emitter__
,
"__event_call__"
:
__event_call__
,
}
# Add extra params in contained in function signature
for
key
,
value
in
extra_params
.
items
():
if
key
in
sig
.
parameters
:
params
[
key
]
=
value
if
"__user__"
in
sig
.
parameters
:
__user__
=
{
"id"
:
user
.
id
,
"email"
:
user
.
email
,
"name"
:
user
.
name
,
"role"
:
user
.
role
,
}
try
:
if
hasattr
(
function_module
,
"UserValves"
):
__user__
[
"valves"
]
=
function_module
.
UserValves
(
**
Functions
.
get_user_valves_by_id_and_user_id
(
filter_id
,
user
.
id
)
)
except
Exception
as
e
:
print
(
e
)
params
=
{
**
params
,
"__user__"
:
__user__
}
if
inspect
.
iscoroutinefunction
(
inlet
):
body
=
await
inlet
(
**
params
)
else
:
body
=
inlet
(
**
params
)
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
raise
e
if
skip_files
:
if
"files"
in
body
:
del
body
[
"files"
]
return
body
,
{}
async
def
chat_completion_tools_handler
(
body
,
user
,
__event_emitter__
,
__event_call__
):
skip_files
=
None
contexts
=
[]
citations
=
None
task_model_id
=
get_task_model_id
(
body
[
"model"
])
# If tool_ids field is present, call the functions
if
"tool_ids"
in
body
:
print
(
body
[
"tool_ids"
])
for
tool_id
in
body
[
"tool_ids"
]:
print
(
tool_id
)
try
:
response
,
citation
,
file_handler
=
await
get_function_call_response
(
messages
=
body
[
"messages"
],
files
=
body
.
get
(
"files"
,
[]),
tool_id
=
tool_id
,
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
,
task_model_id
=
task_model_id
,
user
=
user
,
__event_emitter__
=
__event_emitter__
,
__event_call__
=
__event_call__
,
)
show_citations
=
False
citations
=
[]
print
(
file_handler
)
if
isinstance
(
response
,
str
):
contexts
.
append
(
response
)
if
citation
:
if
citations
is
None
:
citations
=
[
citation
]
else
:
citations
.
append
(
citation
)
if
file_handler
:
skip_files
=
True
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
del
body
[
"tool_ids"
]
print
(
f
"tool_contexts:
{
contexts
}
"
)
if
skip_files
:
if
"files"
in
body
:
del
body
[
"files"
]
return
body
,
{
**
({
"contexts"
:
contexts
}
if
contexts
is
not
None
else
{}),
**
({
"citations"
:
citations
}
if
citations
is
not
None
else
{}),
}
async
def
chat_completion_files_handler
(
body
):
contexts
=
[]
citations
=
None
if
"files"
in
body
:
files
=
body
[
"files"
]
del
body
[
"files"
]
contexts
,
citations
=
get_rag_context
(
files
=
files
,
messages
=
body
[
"messages"
],
embedding_function
=
rag_app
.
state
.
EMBEDDING_FUNCTION
,
k
=
rag_app
.
state
.
config
.
TOP_K
,
reranking_function
=
rag_app
.
state
.
sentence_transformer_rf
,
r
=
rag_app
.
state
.
config
.
RELEVANCE_THRESHOLD
,
hybrid_search
=
rag_app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
,
)
log
.
debug
(
f
"rag_contexts:
{
contexts
}
, citations:
{
citations
}
"
)
return
body
,
{
**
({
"contexts"
:
contexts
}
if
contexts
is
not
None
else
{}),
**
({
"citations"
:
citations
}
if
citations
is
not
None
else
{}),
}
class
ChatCompletionMiddleware
(
BaseHTTPMiddleware
):
async
def
dispatch
(
self
,
request
:
Request
,
call_next
):
if
request
.
method
==
"POST"
and
any
(
endpoint
in
request
.
url
.
path
for
endpoint
in
[
"/ollama/api/chat"
,
"/chat/completions"
]
):
log
.
debug
(
f
"request.url.path:
{
request
.
url
.
path
}
"
)
# Read the original request body
body
=
await
request
.
body
()
body_str
=
body
.
decode
(
"utf-8"
)
data
=
json
.
loads
(
body_str
)
if
body_str
else
{}
user
=
get_current_user
(
request
,
get_http_authorization_cred
(
request
.
headers
.
get
(
"Authorization"
)),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files
=
False
if
data
.
get
(
"citations"
):
show_citations
=
True
del
data
[
"citations"
]
model_id
=
data
[
"model"
]
if
model_id
not
in
app
.
state
.
MODELS
:
raise
HTTPException
(
status_code
=
status
.
HTTP_404_NOT_FOUND
,
detail
=
"Model not found"
,
try
:
body
,
model
,
user
=
await
get_body_and_model_and_user
(
request
)
except
Exception
as
e
:
return
JSONResponse
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
content
=
{
"detail"
:
str
(
e
)},
)
model
=
app
.
state
.
MODELS
[
model_id
]
def
get_priority
(
function_id
):
function
=
Functions
.
get_function_by_id
(
function_id
)
if
function
is
not
None
and
hasattr
(
function
,
"valves"
):
return
(
function
.
valves
if
function
.
valves
else
{}).
get
(
"priority"
,
0
)
return
0
filter_ids
=
[
function
.
id
for
function
in
Functions
.
get_global_filter_functions
()
]
if
"info"
in
model
and
"meta"
in
model
[
"info"
]:
filter_ids
.
extend
(
model
[
"info"
][
"meta"
].
get
(
"filterIds"
,
[]))
filter_ids
=
list
(
set
(
filter_ids
))
enabled_filter_ids
=
[
function
.
id
for
function
in
Functions
.
get_functions_by_type
(
"filter"
,
active_only
=
True
# Extract session_id, chat_id and message_id from the request body
session_id
=
None
if
"session_id"
in
body
:
session_id
=
body
[
"session_id"
]
del
body
[
"session_id"
]
chat_id
=
None
if
"chat_id"
in
body
:
chat_id
=
body
[
"chat_id"
]
del
body
[
"chat_id"
]
message_id
=
None
if
"id"
in
body
:
message_id
=
body
[
"id"
]
del
body
[
"id"
]
async
def
__event_emitter__
(
data
):
await
sio
.
emit
(
"chat-events"
,
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"data"
:
data
,
},
to
=
session_id
,
)
]
filter_ids
=
[
filter_id
for
filter_id
in
filter_ids
if
filter_id
in
enabled_filter_ids
]
filter_ids
.
sort
(
key
=
get_priority
)
for
filter_id
in
filter_ids
:
filter
=
Functions
.
get_function_by_id
(
filter_id
)
if
filter
:
if
filter_id
in
webui_app
.
state
.
FUNCTIONS
:
function_module
=
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
else
:
function_module
,
function_type
,
frontmatter
=
(
load_function_module_by_id
(
filter_id
)
)
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
=
function_module
# Check if the function has a file_handler variable
if
hasattr
(
function_module
,
"file_handler"
):
skip_files
=
function_module
.
file_handler
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
filter_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
async
def
__event_call__
(
data
):
response
=
await
sio
.
call
(
"chat-events"
,
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"data"
:
data
},
to
=
session_id
,
)
return
response
try
:
if
hasattr
(
function_module
,
"inlet"
):
inlet
=
function_module
.
inlet
# Get the signature of the function
sig
=
inspect
.
signature
(
inlet
)
params
=
{
"body"
:
data
}
if
"__user__"
in
sig
.
parameters
:
__user__
=
{
"id"
:
user
.
id
,
"email"
:
user
.
email
,
"name"
:
user
.
name
,
"role"
:
user
.
role
,
}
try
:
if
hasattr
(
function_module
,
"UserValves"
):
__user__
[
"valves"
]
=
function_module
.
UserValves
(
**
Functions
.
get_user_valves_by_id_and_user_id
(
filter_id
,
user
.
id
)
)
except
Exception
as
e
:
print
(
e
)
params
=
{
**
params
,
"__user__"
:
__user__
}
if
"__id__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__id__"
:
filter_id
,
}
if
inspect
.
iscoroutinefunction
(
inlet
):
data
=
await
inlet
(
**
params
)
else
:
data
=
inlet
(
**
params
)
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
return
JSONResponse
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
content
=
{
"detail"
:
str
(
e
)},
)
# Initialize data_items to store additional data to be sent to the client
data_items
=
[]
# Set the task model
task_model_id
=
data
[
"model"
]
# Check if the user has a custom task model and use that model
if
app
.
state
.
MODELS
[
task_model_id
][
"owned_by"
]
==
"ollama"
:
if
(
app
.
state
.
config
.
TASK_MODEL
and
app
.
state
.
config
.
TASK_MODEL
in
app
.
state
.
MODELS
):
task_model_id
=
app
.
state
.
config
.
TASK_MODEL
else
:
if
(
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
and
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
in
app
.
state
.
MODELS
):
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
prompt
=
get_last_user_message
(
data
[
"messages"
])
context
=
""
# If tool_ids field is present, call the functions
if
"tool_ids"
in
data
:
print
(
data
[
"tool_ids"
])
for
tool_id
in
data
[
"tool_ids"
]:
print
(
tool_id
)
try
:
response
,
citation
,
file_handler
=
(
await
get_function_call_response
(
messages
=
data
[
"messages"
],
files
=
data
.
get
(
"files"
,
[]),
tool_id
=
tool_id
,
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
,
task_model_id
=
task_model_id
,
user
=
user
,
)
)
# Initialize context, and citations
contexts
=
[]
citations
=
[]
print
(
file_handler
)
if
isinstance
(
response
,
str
):
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
response
if
citation
:
citations
.
append
(
citation
)
show_citations
=
True
if
file_handler
:
skip_files
=
True
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
del
data
[
"tool_ids"
]
print
(
f
"tool_context:
{
context
}
"
)
# If files field is present, generate RAG completions
# If skip_files is True, skip the RAG completions
if
"files"
in
data
:
if
not
skip_files
:
data
=
{
**
data
}
rag_context
,
rag_citations
=
get_rag_context
(
files
=
data
[
"files"
],
messages
=
data
[
"messages"
],
embedding_function
=
rag_app
.
state
.
EMBEDDING_FUNCTION
,
k
=
rag_app
.
state
.
config
.
TOP_K
,
reranking_function
=
rag_app
.
state
.
sentence_transformer_rf
,
r
=
rag_app
.
state
.
config
.
RELEVANCE_THRESHOLD
,
hybrid_search
=
rag_app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
,
)
if
rag_context
:
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
rag_context
try
:
body
,
flags
=
await
chat_completion_functions_handler
(
body
,
model
,
user
,
__event_emitter__
,
__event_call__
)
except
Exception
as
e
:
return
JSONResponse
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
content
=
{
"detail"
:
str
(
e
)},
)
log
.
debug
(
f
"rag_context:
{
rag_context
}
, citations:
{
citations
}
"
)
try
:
body
,
flags
=
await
chat_completion_tools_handler
(
body
,
user
,
__event_emitter__
,
__event_call__
)
if
rag_citations
:
citations
.
extend
(
rag_citations
)
contexts
.
extend
(
flags
.
get
(
"contexts"
,
[]))
citations
.
extend
(
flags
.
get
(
"citations"
,
[]))
except
Exception
as
e
:
print
(
e
)
pass
del
data
[
"files"
]
try
:
body
,
flags
=
await
chat_completion_files_handler
(
body
)
if
show_citations
and
len
(
citations
)
>
0
:
data_items
.
append
({
"citations"
:
citations
})
contexts
.
extend
(
flags
.
get
(
"contexts"
,
[]))
citations
.
extend
(
flags
.
get
(
"citations"
,
[]))
except
Exception
as
e
:
print
(
e
)
pass
if
context
!=
""
:
system_prompt
=
rag_template
(
rag_app
.
state
.
config
.
RAG_TEMPLATE
,
context
,
prompt
)
print
(
system_prompt
)
data
[
"messages"
]
=
add_or_update_system_message
(
system_prompt
,
data
[
"messages"
]
# If context is not empty, insert it into the messages
if
len
(
contexts
)
>
0
:
context_string
=
"/n"
.
join
(
contexts
).
strip
()
prompt
=
get_last_user_message
(
body
[
"messages"
])
body
[
"messages"
]
=
add_or_update_system_message
(
rag_template
(
rag_app
.
state
.
config
.
RAG_TEMPLATE
,
context_string
,
prompt
),
body
[
"messages"
],
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
# If there are citations, add them to the data_items
if
len
(
citations
)
>
0
:
data_items
.
append
({
"citations"
:
citations
})
modified_body_bytes
=
json
.
dumps
(
body
).
encode
(
"utf-8"
)
# Replace the request body with the modified one
request
.
_body
=
modified_body_bytes
# Set custom header to ensure content-length matches new body length
...
...
@@ -654,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware)
##################################
def
filter_pipeline
(
payload
,
user
):
user
=
{
"id"
:
user
.
id
,
"email"
:
user
.
email
,
"name"
:
user
.
name
,
"role"
:
user
.
role
}
model_id
=
payload
[
"model"
]
def
get_sorted_filters
(
model_id
):
filters
=
[
model
for
model
in
app
.
state
.
MODELS
.
values
()
...
...
@@ -672,6 +780,13 @@ def filter_pipeline(payload, user):
)
]
sorted_filters
=
sorted
(
filters
,
key
=
lambda
x
:
x
[
"pipeline"
][
"priority"
])
return
sorted_filters
def
filter_pipeline
(
payload
,
user
):
user
=
{
"id"
:
user
.
id
,
"email"
:
user
.
email
,
"name"
:
user
.
name
,
"role"
:
user
.
role
}
model_id
=
payload
[
"model"
]
sorted_filters
=
get_sorted_filters
(
model_id
)
model
=
app
.
state
.
MODELS
[
model_id
]
...
...
@@ -704,25 +819,12 @@ def filter_pipeline(payload, user):
print
(
f
"Connection error:
{
e
}
"
)
if
r
is
not
None
:
try
:
res
=
r
.
json
()
except
:
pass
res
=
r
.
json
()
if
"detail"
in
res
:
raise
Exception
(
r
.
status_code
,
res
[
"detail"
])
else
:
pass
if
"pipeline"
not
in
app
.
state
.
MODELS
[
model_id
]:
if
"chat_id"
in
payload
:
del
payload
[
"chat_id"
]
if
"title"
in
payload
:
del
payload
[
"title"
]
if
"task"
in
payload
:
del
payload
[
"task"
]
if
"pipeline"
not
in
app
.
state
.
MODELS
[
model_id
]
and
"task"
in
payload
:
del
payload
[
"task"
]
return
payload
...
...
@@ -787,6 +889,14 @@ app.add_middleware(
)
@
app
.
middleware
(
"http"
)
async
def
commit_session_after_request
(
request
:
Request
,
call_next
):
response
=
await
call_next
(
request
)
log
.
debug
(
"Commit session after request"
)
Session
.
commit
()
return
response
@
app
.
middleware
(
"http"
)
async
def
check_url
(
request
:
Request
,
call_next
):
if
len
(
app
.
state
.
MODELS
)
==
0
:
...
...
@@ -863,12 +973,16 @@ async def get_all_models():
model
[
"info"
]
=
custom_model
.
model_dump
()
else
:
owned_by
=
"openai"
pipe
=
None
for
model
in
models
:
if
(
custom_model
.
base_model_id
==
model
[
"id"
]
or
custom_model
.
base_model_id
==
model
[
"id"
].
split
(
":"
)[
0
]
):
owned_by
=
model
[
"owned_by"
]
if
"pipe"
in
model
:
pipe
=
model
[
"pipe"
]
break
models
.
append
(
...
...
@@ -880,11 +994,11 @@ async def get_all_models():
"owned_by"
:
owned_by
,
"info"
:
custom_model
.
model_dump
(),
"preset"
:
True
,
**
({
"pipe"
:
pipe
}
if
pipe
is
not
None
else
{}),
}
)
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
}
webui_app
.
state
.
MODELS
=
app
.
state
.
MODELS
return
models
...
...
@@ -945,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
)
model
=
app
.
state
.
MODELS
[
model_id
]
filters
=
[
model
for
model
in
app
.
state
.
MODELS
.
values
()
if
"pipeline"
in
model
and
"type"
in
model
[
"pipeline"
]
and
model
[
"pipeline"
][
"type"
]
==
"filter"
and
(
model
[
"pipeline"
][
"pipelines"
]
==
[
"*"
]
or
any
(
model_id
==
target_model_id
for
target_model_id
in
model
[
"pipeline"
][
"pipelines"
]
)
)
]
sorted_filters
=
sorted
(
filters
,
key
=
lambda
x
:
x
[
"pipeline"
][
"priority"
])
sorted_filters
=
get_sorted_filters
(
model_id
)
if
"pipeline"
in
model
:
sorted_filters
=
[
model
]
+
sorted_filters
...
...
@@ -1008,6 +1107,25 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
else
:
pass
async
def
__event_emitter__
(
event_data
):
await
sio
.
emit
(
"chat-events"
,
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"data"
:
event_data
,
},
to
=
data
[
"session_id"
],
)
async
def
__event_call__
(
event_data
):
response
=
await
sio
.
call
(
"chat-events"
,
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"data"
:
event_data
},
to
=
data
[
"session_id"
],
)
return
response
def
get_priority
(
function_id
):
function
=
Functions
.
get_function_by_id
(
function_id
)
if
function
is
not
None
and
hasattr
(
function
,
"valves"
):
...
...
@@ -1032,68 +1150,74 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
for
filter_id
in
filter_ids
:
filter
=
Functions
.
get_function_by_id
(
filter_id
)
if
filter
:
if
filter_id
in
webui_app
.
state
.
FUNCTIONS
:
function_module
=
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
else
:
function_module
,
function_type
,
frontmatter
=
(
load_function_module_by_id
(
filter_id
)
)
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
=
function_module
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
filter_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
if
not
filter
:
continue
try
:
if
hasattr
(
function_module
,
"outlet"
):
outlet
=
function_module
.
outlet
if
filter_id
in
webui_app
.
state
.
FUNCTIONS
:
function_module
=
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
else
:
function_module
,
_
,
_
=
load_function_module_by_id
(
filter_id
)
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
=
function_module
# Get the signature of the function
sig
=
inspect
.
signature
(
outlet
)
params
=
{
"body"
:
data
}
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
filter_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
if
"__user__"
in
sig
.
parameters
:
__user__
=
{
"id"
:
user
.
id
,
"email"
:
user
.
email
,
"name"
:
user
.
name
,
"role"
:
user
.
role
,
}
try
:
if
hasattr
(
function_module
,
"UserValves"
):
__user__
[
"valves"
]
=
function_module
.
UserValves
(
**
Functions
.
get_user_valves_by_id_and_user_id
(
filter_id
,
user
.
id
)
)
except
Exception
as
e
:
print
(
e
)
params
=
{
**
params
,
"__user__"
:
__user__
}
if
"__id__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__id__"
:
filter_id
,
}
if
inspect
.
iscoroutinefunction
(
outlet
):
data
=
await
outlet
(
**
params
)
else
:
data
=
outlet
(
**
params
)
if
not
hasattr
(
function_module
,
"outlet"
):
continue
try
:
outlet
=
function_module
.
outlet
# Get the signature of the function
sig
=
inspect
.
signature
(
outlet
)
params
=
{
"body"
:
data
}
# Extra parameters to be passed to the function
extra_params
=
{
"__model__"
:
model
,
"__id__"
:
filter_id
,
"__event_emitter__"
:
__event_emitter__
,
"__event_call__"
:
__event_call__
,
}
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
return
JSONResponse
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
content
=
{
"detail"
:
str
(
e
)},
)
# Add extra params in contained in function signature
for
key
,
value
in
extra_params
.
items
():
if
key
in
sig
.
parameters
:
params
[
key
]
=
value
if
"__user__"
in
sig
.
parameters
:
__user__
=
{
"id"
:
user
.
id
,
"email"
:
user
.
email
,
"name"
:
user
.
name
,
"role"
:
user
.
role
,
}
try
:
if
hasattr
(
function_module
,
"UserValves"
):
__user__
[
"valves"
]
=
function_module
.
UserValves
(
**
Functions
.
get_user_valves_by_id_and_user_id
(
filter_id
,
user
.
id
)
)
except
Exception
as
e
:
print
(
e
)
params
=
{
**
params
,
"__user__"
:
__user__
}
if
inspect
.
iscoroutinefunction
(
outlet
):
data
=
await
outlet
(
**
params
)
else
:
data
=
outlet
(
**
params
)
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
return
JSONResponse
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
content
=
{
"detail"
:
str
(
e
)},
)
return
data
...
...
@@ -1169,19 +1293,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if
app
.
state
.
MODELS
[
model_id
][
"owned_by"
]
==
"ollama"
:
if
app
.
state
.
config
.
TASK_MODEL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
else
:
if
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
model_id
=
get_task_model_id
(
model_id
)
print
(
model_id
)
model
=
app
.
state
.
MODELS
[
model_id
]
template
=
app
.
state
.
config
.
TITLE_GENERATION_PROMPT_TEMPLATE
...
...
@@ -1200,7 +1314,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"stream"
:
False
,
"max_tokens"
:
50
,
"chat_id"
:
form_data
.
get
(
"chat_id"
,
None
),
"t
itle"
:
True
,
"t
ask"
:
TASKS
.
TITLE_GENERATION
,
}
log
.
debug
(
payload
)
...
...
@@ -1213,6 +1327,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content
=
{
"detail"
:
e
.
args
[
1
]},
)
if
"chat_id"
in
payload
:
del
payload
[
"chat_id"
]
return
await
generate_chat_completions
(
form_data
=
payload
,
user
=
user
)
...
...
@@ -1235,19 +1352,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if
app
.
state
.
MODELS
[
model_id
][
"owned_by"
]
==
"ollama"
:
if
app
.
state
.
config
.
TASK_MODEL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
else
:
if
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
model_id
=
get_task_model_id
(
model_id
)
print
(
model_id
)
model
=
app
.
state
.
MODELS
[
model_id
]
template
=
app
.
state
.
config
.
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
...
...
@@ -1260,7 +1367,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages"
:
[{
"role"
:
"user"
,
"content"
:
content
}],
"stream"
:
False
,
"max_tokens"
:
30
,
"task"
:
T
rue
,
"task"
:
T
ASKS
.
QUERY_GENERATION
,
}
print
(
payload
)
...
...
@@ -1273,6 +1380,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content
=
{
"detail"
:
e
.
args
[
1
]},
)
if
"chat_id"
in
payload
:
del
payload
[
"chat_id"
]
return
await
generate_chat_completions
(
form_data
=
payload
,
user
=
user
)
...
...
@@ -1289,19 +1399,9 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if
app
.
state
.
MODELS
[
model_id
][
"owned_by"
]
==
"ollama"
:
if
app
.
state
.
config
.
TASK_MODEL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
else
:
if
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
model_id
=
get_task_model_id
(
model_id
)
print
(
model_id
)
model
=
app
.
state
.
MODELS
[
model_id
]
template
=
'''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
...
...
@@ -1324,7 +1424,7 @@ Message: """{{prompt}}"""
"stream"
:
False
,
"max_tokens"
:
4
,
"chat_id"
:
form_data
.
get
(
"chat_id"
,
None
),
"task"
:
T
rue
,
"task"
:
T
ASKS
.
EMOJI_GENERATION
,
}
log
.
debug
(
payload
)
...
...
@@ -1337,6 +1437,9 @@ Message: """{{prompt}}"""
content
=
{
"detail"
:
e
.
args
[
1
]},
)
if
"chat_id"
in
payload
:
del
payload
[
"chat_id"
]
return
await
generate_chat_completions
(
form_data
=
payload
,
user
=
user
)
...
...
@@ -1353,22 +1456,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if
app
.
state
.
MODELS
[
model_id
][
"owned_by"
]
==
"ollama"
:
if
app
.
state
.
config
.
TASK_MODEL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
else
:
if
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
:
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
if
task_model_id
in
app
.
state
.
MODELS
:
model_id
=
task_model_id
model_id
=
get_task_model_id
(
model_id
)
print
(
model_id
)
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try
:
context
,
citation
,
file_handler
=
await
get_function_call_response
(
context
,
_
,
_
=
await
get_function_call_response
(
form_data
[
"messages"
],
form_data
.
get
(
"files"
,
[]),
form_data
[
"tool_id"
],
...
...
@@ -1432,6 +1526,7 @@ async def upload_pipeline(
os
.
makedirs
(
upload_folder
,
exist_ok
=
True
)
file_path
=
os
.
path
.
join
(
upload_folder
,
file
.
filename
)
r
=
None
try
:
# Save the uploaded file
with
open
(
file_path
,
"wb"
)
as
buffer
:
...
...
@@ -1455,7 +1550,9 @@ async def upload_pipeline(
print
(
f
"Connection error:
{
e
}
"
)
detail
=
"Pipeline not found"
status_code
=
status
.
HTTP_404_NOT_FOUND
if
r
is
not
None
:
status_code
=
r
.
status_code
try
:
res
=
r
.
json
()
if
"detail"
in
res
:
...
...
@@ -1464,7 +1561,7 @@ async def upload_pipeline(
pass
raise
HTTPException
(
status_code
=
(
r
.
status_code
if
r
is
not
None
else
status
.
HTTP_404_NOT_FOUND
)
,
status_code
=
status_code
,
detail
=
detail
,
)
finally
:
...
...
@@ -1563,8 +1660,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
async
def
get_pipelines
(
urlIdx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_admin_user
)):
r
=
None
try
:
urlIdx
url
=
openai_app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
urlIdx
]
key
=
openai_app
.
state
.
config
.
OPENAI_API_KEYS
[
urlIdx
]
...
...
@@ -1596,7 +1691,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@
app
.
get
(
"/api/pipelines/{pipeline_id}/valves"
)
async
def
get_pipeline_valves
(
urlIdx
:
Optional
[
int
],
pipeline_id
:
str
,
user
=
Depends
(
get_admin_user
)
urlIdx
:
Optional
[
int
],
pipeline_id
:
str
,
user
=
Depends
(
get_admin_user
),
):
models
=
await
get_all_models
()
r
=
None
...
...
@@ -1634,7 +1731,9 @@ async def get_pipeline_valves(
@
app
.
get
(
"/api/pipelines/{pipeline_id}/valves/spec"
)
async
def
get_pipeline_valves_spec
(
urlIdx
:
Optional
[
int
],
pipeline_id
:
str
,
user
=
Depends
(
get_admin_user
)
urlIdx
:
Optional
[
int
],
pipeline_id
:
str
,
user
=
Depends
(
get_admin_user
),
):
models
=
await
get_all_models
()
...
...
@@ -1920,7 +2019,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
if
existing_user
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EMAIL_TAKEN
)
picture_url
=
user_data
.
get
(
"picture"
,
""
)
picture_claim
=
webui_app
.
state
.
config
.
OAUTH_PICTURE_CLAIM
picture_url
=
user_data
.
get
(
picture_claim
,
""
)
if
picture_url
:
# Download the profile image into a base64 string
try
:
...
...
@@ -1940,6 +2040,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
picture_url
=
""
if
not
picture_url
:
picture_url
=
"/user.png"
username_claim
=
webui_app
.
state
.
config
.
OAUTH_USERNAME_CLAIM
role
=
(
"admin"
if
Users
.
get_num_users
()
==
0
...
...
@@ -1950,7 +2051,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
password
=
get_password_hash
(
str
(
uuid
.
uuid4
())
),
# Random password, not used
name
=
user_data
.
get
(
"name"
,
"User"
),
name
=
user_data
.
get
(
username_claim
,
"User"
),
profile_image_url
=
picture_url
,
role
=
role
,
oauth_sub
=
provider_sub
,
...
...
@@ -2008,7 +2109,7 @@ async def get_opensearch_xml():
<ShortName>
{
WEBUI_NAME
}
</ShortName>
<Description>Search
{
WEBUI_NAME
}
</Description>
<InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">
{
WEBUI_URL
}
/favicon.png</Image>
<Image width="16" height="16" type="image/x-icon">
{
WEBUI_URL
}
/
static/
favicon.png</Image>
<Url type="text/html" method="get" template="
{
WEBUI_URL
}
/?q=
{
"
{
searchTerms
}
"
}
"/>
<moz:SearchForm>
{
WEBUI_URL
}
</moz:SearchForm>
</OpenSearchDescription>
...
...
@@ -2021,6 +2122,12 @@ async def healthcheck():
return
{
"status"
:
True
}
@
app
.
get
(
"/health/db"
)
async
def
healthcheck_with_db
():
Session
.
execute
(
text
(
"SELECT 1;"
)).
all
()
return
{
"status"
:
True
}
app
.
mount
(
"/static"
,
StaticFiles
(
directory
=
STATIC_DIR
),
name
=
"static"
)
app
.
mount
(
"/cache"
,
StaticFiles
(
directory
=
CACHE_DIR
),
name
=
"cache"
)
...
...
backend/migrations/README
0 → 100644
View file @
9bcd4ce5
Generic single-database configuration.
Create new migrations with
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"
backend/migrations/env.py
0 → 100644
View file @
9bcd4ce5
import
os
from
logging.config
import
fileConfig
from
sqlalchemy
import
engine_from_config
from
sqlalchemy
import
pool
from
alembic
import
context
from
apps.webui.models.auths
import
Auth
from
apps.webui.models.chats
import
Chat
from
apps.webui.models.documents
import
Document
from
apps.webui.models.memories
import
Memory
from
apps.webui.models.models
import
Model
from
apps.webui.models.prompts
import
Prompt
from
apps.webui.models.tags
import
Tag
,
ChatIdTag
from
apps.webui.models.tools
import
Tool
from
apps.webui.models.users
import
User
from
apps.webui.models.files
import
File
from
apps.webui.models.functions
import
Function
from
config
import
DATABASE_URL
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config
=
context
.
config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if
config
.
config_file_name
is
not
None
:
fileConfig
(
config
.
config_file_name
)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata
=
Auth
.
metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
DB_URL
=
DATABASE_URL
if
DB_URL
:
config
.
set_main_option
(
"sqlalchemy.url"
,
DB_URL
.
replace
(
"%"
,
"%%"
))
def
run_migrations_offline
()
->
None
:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url
=
config
.
get_main_option
(
"sqlalchemy.url"
)
context
.
configure
(
url
=
url
,
target_metadata
=
target_metadata
,
literal_binds
=
True
,
dialect_opts
=
{
"paramstyle"
:
"named"
},
)
with
context
.
begin_transaction
():
context
.
run_migrations
()
def
run_migrations_online
()
->
None
:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable
=
engine_from_config
(
config
.
get_section
(
config
.
config_ini_section
,
{}),
prefix
=
"sqlalchemy."
,
poolclass
=
pool
.
NullPool
,
)
with
connectable
.
connect
()
as
connection
:
context
.
configure
(
connection
=
connection
,
target_metadata
=
target_metadata
)
with
context
.
begin_transaction
():
context
.
run_migrations
()
if
context
.
is_offline_mode
():
run_migrations_offline
()
else
:
run_migrations_online
()
Prev
1
2
3
4
5
6
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment