Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
e3e02e04
Commit
e3e02e04
authored
Jul 09, 2024
by
Michael Poluektov
Browse files
refac: backend/main.py
parent
f9e3c47d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
214 additions
and
280 deletions
+214
-280
backend/main.py
backend/main.py
+214
-280
No files found.
backend/main.py
View file @
e3e02e04
import
base64
import
uuid
import
subprocess
from
contextlib
import
asynccontextmanager
from
authlib.integrations.starlette_client
import
OAuth
from
authlib.oidc.core
import
UserInfo
from
bs4
import
BeautifulSoup
import
json
import
markdown
import
time
import
os
import
sys
...
...
@@ -19,14 +16,11 @@ import shutil
import
os
import
uuid
import
inspect
import
asyncio
from
fastapi.concurrency
import
run_in_threadpool
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
,
UploadFile
,
File
,
Form
from
fastapi.staticfiles
import
StaticFiles
from
fastapi.responses
import
JSONResponse
from
fastapi
import
HTTPException
from
fastapi.middleware.wsgi
import
WSGIMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
sqlalchemy
import
text
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
...
...
@@ -38,7 +32,6 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse
from
apps.socket.main
import
sio
,
app
as
socket_app
from
apps.ollama.main
import
(
app
as
ollama_app
,
OpenAIChatCompletionForm
,
get_all_models
as
get_ollama_models
,
generate_openai_chat_completion
as
generate_ollama_chat_completion
,
)
...
...
@@ -56,14 +49,14 @@ from apps.webui.main import (
get_pipe_models
,
generate_function_chat_completion
,
)
from
apps.webui.internal.db
import
Session
,
SessionLocal
from
apps.webui.internal.db
import
Session
from
pydantic
import
BaseModel
from
typing
import
List
,
Optional
,
Iterator
,
Generator
,
Union
from
typing
import
List
,
Optional
from
apps.webui.models.auths
import
Auths
from
apps.webui.models.models
import
Models
,
ModelModel
from
apps.webui.models.models
import
Models
from
apps.webui.models.tools
import
Tools
from
apps.webui.models.functions
import
Functions
from
apps.webui.models.users
import
Users
...
...
@@ -86,14 +79,12 @@ from utils.task import (
from
utils.misc
import
(
get_last_user_message
,
add_or_update_system_message
,
stream_message_template
,
parse_duration
,
)
from
apps.rag.utils
import
get_rag_context
,
rag_template
from
config
import
(
CONFIG_DATA
,
WEBUI_NAME
,
WEBUI_URL
,
WEBUI_AUTH
,
...
...
@@ -101,7 +92,6 @@ from config import (
VERSION
,
CHANGELOG
,
FRONTEND_BUILD_DIR
,
UPLOAD_DIR
,
CACHE_DIR
,
STATIC_DIR
,
DEFAULT_LOCALE
,
...
...
@@ -128,9 +118,8 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE
,
WEBUI_SESSION_COOKIE_SECURE
,
AppConfig
,
BACKEND_DIR
,
DATABASE_URL
,
)
from
constants
import
ERROR_MESSAGES
,
WEBHOOK_MESSAGES
,
TASKS
from
utils.webhook
import
post_webhook
...
...
@@ -355,19 +344,24 @@ async def get_function_call_response(
else
:
content
=
response
[
"choices"
][
0
][
"message"
][
"content"
]
if
content
is
None
:
return
None
,
None
,
False
# Parse the function response
if
content
is
not
None
:
print
(
f
"content:
{
content
}
"
)
result
=
json
.
loads
(
content
)
print
(
result
)
citation
=
None
if
"name"
not
in
result
:
return
None
,
None
,
False
# Call the function
if
"name"
in
result
:
if
tool_id
in
webui_app
.
state
.
TOOLS
:
toolkit_module
=
webui_app
.
state
.
TOOLS
[
tool_id
]
else
:
toolkit_module
,
frontmatter
=
load_toolkit_module_by_id
(
tool_id
)
toolkit_module
,
_
=
load_toolkit_module_by_id
(
tool_id
)
webui_app
.
state
.
TOOLS
[
tool_id
]
=
toolkit_module
file_handler
=
False
...
...
@@ -376,13 +370,9 @@ async def get_function_call_response(
file_handler
=
True
print
(
"file_handler: "
,
file_handler
)
if
hasattr
(
toolkit_module
,
"valves"
)
and
hasattr
(
toolkit_module
,
"Valves"
):
if
hasattr
(
toolkit_module
,
"valves"
)
and
hasattr
(
toolkit_module
,
"Valves"
):
valves
=
Tools
.
get_tool_valves_by_id
(
tool_id
)
toolkit_module
.
valves
=
toolkit_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
toolkit_module
.
valves
=
toolkit_module
.
Valves
(
**
(
valves
if
valves
else
{}))
function
=
getattr
(
toolkit_module
,
result
[
"name"
])
function_result
=
None
...
...
@@ -391,6 +381,21 @@ async def get_function_call_response(
sig
=
inspect
.
signature
(
function
)
params
=
result
[
"parameters"
]
# Extra parameters to be passed to the function
extra_params
=
{
"__model__"
:
model
,
"__id__"
:
tool_id
,
"__messages__"
:
messages
,
"__files__"
:
files
,
"__event_emitter__"
:
__event_emitter__
,
"__event_call__"
:
__event_call__
,
}
# Add extra params in contained in function signature
for
key
,
value
in
extra_params
.
items
():
if
key
in
sig
.
parameters
:
params
[
key
]
=
value
if
"__user__"
in
sig
.
parameters
:
# Call the function with the '__user__' parameter included
__user__
=
{
...
...
@@ -403,55 +408,12 @@ async def get_function_call_response(
try
:
if
hasattr
(
toolkit_module
,
"UserValves"
):
__user__
[
"valves"
]
=
toolkit_module
.
UserValves
(
**
Tools
.
get_user_valves_by_id_and_user_id
(
tool_id
,
user
.
id
)
**
Tools
.
get_user_valves_by_id_and_user_id
(
tool_id
,
user
.
id
)
)
except
Exception
as
e
:
print
(
e
)
params
=
{
**
params
,
"__user__"
:
__user__
}
if
"__messages__"
in
sig
.
parameters
:
# Call the function with the '__messages__' parameter included
params
=
{
**
params
,
"__messages__"
:
messages
,
}
if
"__files__"
in
sig
.
parameters
:
# Call the function with the '__files__' parameter included
params
=
{
**
params
,
"__files__"
:
files
,
}
if
"__model__"
in
sig
.
parameters
:
# Call the function with the '__model__' parameter included
params
=
{
**
params
,
"__model__"
:
model
,
}
if
"__id__"
in
sig
.
parameters
:
# Call the function with the '__id__' parameter included
params
=
{
**
params
,
"__id__"
:
tool_id
,
}
if
"__event_emitter__"
in
sig
.
parameters
:
# Call the function with the '__event_emitter__' parameter included
params
=
{
**
params
,
"__event_emitter__"
:
__event_emitter__
,
}
if
"__event_call__"
in
sig
.
parameters
:
# Call the function with the '__event_call__' parameter included
params
=
{
**
params
,
"__event_call__"
:
__event_call__
,
}
if
inspect
.
iscoroutinefunction
(
function
):
function_result
=
await
function
(
**
params
)
...
...
@@ -484,22 +446,20 @@ async def chat_completion_functions_handler(
filter_ids
=
get_filter_function_ids
(
model
)
for
filter_id
in
filter_ids
:
filter
=
Functions
.
get_function_by_id
(
filter_id
)
if
filter
:
if
not
filter
:
continue
if
filter_id
in
webui_app
.
state
.
FUNCTIONS
:
function_module
=
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
else
:
function_module
,
function_type
,
frontmatter
=
(
load_function_module_by_id
(
filter_id
)
)
function_module
,
_
,
_
=
load_function_module_by_id
(
filter_id
)
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
=
function_module
# Check if the function has a file_handler variable
if
hasattr
(
function_module
,
"file_handler"
):
skip_files
=
function_module
.
file_handler
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
filter_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
...
...
@@ -513,6 +473,19 @@ async def chat_completion_functions_handler(
sig
=
inspect
.
signature
(
inlet
)
params
=
{
"body"
:
body
}
# Extra parameters to be passed to the function
extra_params
=
{
"__model__"
:
model
,
"__id__"
:
filter_id
,
"__event_emitter__"
:
__event_emitter__
,
"__event_call__"
:
__event_call__
,
}
# Add extra params in contained in function signature
for
key
,
value
in
extra_params
.
items
():
if
key
in
sig
.
parameters
:
params
[
key
]
=
value
if
"__user__"
in
sig
.
parameters
:
__user__
=
{
"id"
:
user
.
id
,
...
...
@@ -533,30 +506,6 @@ async def chat_completion_functions_handler(
params
=
{
**
params
,
"__user__"
:
__user__
}
if
"__id__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__id__"
:
filter_id
,
}
if
"__model__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__model__"
:
model
,
}
if
"__event_emitter__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__event_emitter__"
:
__event_emitter__
,
}
if
"__event_call__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__event_call__"
:
__event_call__
,
}
if
inspect
.
iscoroutinefunction
(
inlet
):
body
=
await
inlet
(
**
params
)
else
:
...
...
@@ -1220,18 +1169,16 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
for
filter_id
in
filter_ids
:
filter
=
Functions
.
get_function_by_id
(
filter_id
)
if
filter
:
if
not
filter
:
continue
if
filter_id
in
webui_app
.
state
.
FUNCTIONS
:
function_module
=
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
else
:
function_module
,
function_type
,
frontmatter
=
(
load_function_module_by_id
(
filter_id
)
)
function_module
,
_
,
_
=
load_function_module_by_id
(
filter_id
)
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
=
function_module
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
filter_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
...
...
@@ -1245,6 +1192,19 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
sig
=
inspect
.
signature
(
outlet
)
params
=
{
"body"
:
data
}
# Extra parameters to be passed to the function
extra_params
=
{
"__model__"
:
model
,
"__id__"
:
filter_id
,
"__event_emitter__"
:
__event_emitter__
,
"__event_call__"
:
__event_call__
,
}
# Add extra params in contained in function signature
for
key
,
value
in
extra_params
.
items
():
if
key
in
sig
.
parameters
:
params
[
key
]
=
value
if
"__user__"
in
sig
.
parameters
:
__user__
=
{
"id"
:
user
.
id
,
...
...
@@ -1265,30 +1225,6 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
params
=
{
**
params
,
"__user__"
:
__user__
}
if
"__id__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__id__"
:
filter_id
,
}
if
"__model__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__model__"
:
model
,
}
if
"__event_emitter__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__event_emitter__"
:
__event_emitter__
,
}
if
"__event_call__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__event_call__"
:
__event_call__
,
}
if
inspect
.
iscoroutinefunction
(
outlet
):
data
=
await
outlet
(
**
params
)
else
:
...
...
@@ -1387,7 +1323,6 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
model_id
=
task_model_id
print
(
model_id
)
model
=
app
.
state
.
MODELS
[
model_id
]
template
=
app
.
state
.
config
.
TITLE_GENERATION_PROMPT_TEMPLATE
...
...
@@ -1456,7 +1391,6 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
model_id
=
task_model_id
print
(
model_id
)
model
=
app
.
state
.
MODELS
[
model_id
]
template
=
app
.
state
.
config
.
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
...
...
@@ -1513,7 +1447,6 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
model_id
=
task_model_id
print
(
model_id
)
model
=
app
.
state
.
MODELS
[
model_id
]
template
=
'''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
...
...
@@ -1583,7 +1516,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try
:
context
,
citation
,
file_handler
=
await
get_function_call_response
(
context
,
_
,
_
=
await
get_function_call_response
(
form_data
[
"messages"
],
form_data
.
get
(
"files"
,
[]),
form_data
[
"tool_id"
],
...
...
@@ -1647,6 +1580,7 @@ async def upload_pipeline(
os
.
makedirs
(
upload_folder
,
exist_ok
=
True
)
file_path
=
os
.
path
.
join
(
upload_folder
,
file
.
filename
)
r
=
None
try
:
# Save the uploaded file
with
open
(
file_path
,
"wb"
)
as
buffer
:
...
...
@@ -1670,7 +1604,9 @@ async def upload_pipeline(
print
(
f
"Connection error:
{
e
}
"
)
detail
=
"Pipeline not found"
status_code
=
status
.
HTTP_404_NOT_FOUND
if
r
is
not
None
:
status_code
=
r
.
status_code
try
:
res
=
r
.
json
()
if
"detail"
in
res
:
...
...
@@ -1679,7 +1615,7 @@ async def upload_pipeline(
pass
raise
HTTPException
(
status_code
=
(
r
.
status_code
if
r
is
not
None
else
status
.
HTTP_404_NOT_FOUND
)
,
status_code
=
status_code
,
detail
=
detail
,
)
finally
:
...
...
@@ -1778,8 +1714,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
async
def
get_pipelines
(
urlIdx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_admin_user
)):
r
=
None
try
:
urlIdx
url
=
openai_app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
urlIdx
]
key
=
openai_app
.
state
.
config
.
OPENAI_API_KEYS
[
urlIdx
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment