chenpangpang / open-webui / Commits / a9a6ed8b

Unverified commit a9a6ed8b, authored Aug 02, 2024 by Timothy Jaeryang Baek, committed by GitHub on Aug 02, 2024

Merge pull request #4237 from michaelpoluektov/refactor-webui-main

refactor: Simplify functions

Parents: 64b41655, e6c64282

Showing 5 changed files with 284 additions and 389 deletions (+284, -389):

- backend/apps/socket/main.py (+2, -4)
- backend/apps/webui/main.py (+225, -302)
- backend/apps/webui/models/models.py (+2, -10)
- backend/main.py (+16, -41)
- backend/utils/misc.py (+39, -32)
backend/apps/socket/main.py

```diff
@@ -52,7 +52,6 @@ async def user_join(sid, data):
     user = Users.get_user_by_id(data["id"])
     if user:
         SESSION_POOL[sid] = user.id
         if user.id in USER_POOL:
             USER_POOL[user.id].append(sid)
@@ -80,7 +79,6 @@ def get_models_in_use():
 @sio.on("usage")
 async def usage(sid, data):
     model_id = data["model"]

     # Cancel previous callback if there is one
@@ -139,7 +137,7 @@ async def disconnect(sid):
         print(f"Unknown session ID {sid} disconnected")


-async def get_event_emitter(request_info):
+def get_event_emitter(request_info):
     async def __event_emitter__(event_data):
         await sio.emit(
             "chat-events",
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
     return __event_emitter__


-async def get_event_call(request_info):
+def get_event_call(request_info):
     async def __event_call__(event_data):
         response = await sio.call(
             "chat-events",
```
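Note on the two signature changes above: `get_event_emitter` and `get_event_call` never await anything themselves; they only build and return an async closure. Dropping `async` from the factories lets every call site drop its `await`, which accounts for most of this commit's churn in `backend/main.py`. A minimal sketch of the pattern, assuming a python-socketio `AsyncServer` named `sio`; the payload shape and the `room` routing key here are simplifications, not the real open-webui wire format:

```python
import socketio

sio = socketio.AsyncServer(async_mode="asgi")


def get_event_emitter(request_info: dict):
    # Plain synchronous factory: it only closes over request_info.
    # Only the returned closure actually needs to be a coroutine.
    async def __event_emitter__(event_data):
        await sio.emit(
            "chat-events",
            {"data": event_data, "chat_id": request_info.get("chat_id")},
            room=request_info.get("session_id"),  # assumed routing key
        )

    return __event_emitter__


# Usage: no await on the factory, only on the emitter it returns.
#   emitter = get_event_emitter({"chat_id": "c1", "session_id": "s1"})
#   await emitter({"type": "status", "done": True})
```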
backend/apps/webui/main.py

```diff
-from fastapi import FastAPI, Depends
-from fastapi.routing import APIRoute
+from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
-from starlette.middleware.sessions import SessionMiddleware
-from sqlalchemy.orm import Session

 from apps.webui.routers import (
     auths,
     users,
@@ -22,12 +19,15 @@ from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
 from apps.webui.utils import load_function_module_by_id

-from utils.misc import stream_message_template
+from utils.misc import (
+    openai_chat_chunk_message_template,
+    openai_chat_completion_message_template,
+    add_or_update_system_message,
+)
 from utils.task import prompt_template

 from config import (
-    WEBUI_BUILD_HASH,
     SHOW_ADMIN_DETAILS,
     ADMIN_EMAIL,
     WEBUI_AUTH,
@@ -51,11 +51,9 @@ from config import (
 from apps.socket.main import get_event_call, get_event_emitter

 import inspect
-import uuid
-import time
 import json

-from typing import Iterator, Generator, AsyncGenerator, Optional
+from typing import Iterator, Generator, AsyncGenerator
 from pydantic import BaseModel

 app = FastAPI()
```
```diff
@@ -127,29 +125,29 @@ async def get_status():
     }


+def get_function_module(pipe_id: str):
+    # Check if function is already loaded
+    if pipe_id not in app.state.FUNCTIONS:
+        function_module, _, _ = load_function_module_by_id(pipe_id)
+        app.state.FUNCTIONS[pipe_id] = function_module
+    else:
+        function_module = app.state.FUNCTIONS[pipe_id]
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(pipe_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+    return function_module
+
+
 async def get_pipe_models():
     pipes = Functions.get_functions_by_type("pipe", active_only=True)
     pipe_models = []

     for pipe in pipes:
-        # Check if function is already loaded
-        if pipe.id not in app.state.FUNCTIONS:
-            function_module, function_type, frontmatter = load_function_module_by_id(pipe.id)
-            app.state.FUNCTIONS[pipe.id] = function_module
-        else:
-            function_module = app.state.FUNCTIONS[pipe.id]
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(pipe.id)
-            function_module.valves = function_module.Valves(**(valves if valves else {}))
+        function_module = get_function_module(pipe.id)

         # Check if function is a manifold
-        if hasattr(function_module, "type"):
-            if function_module.type == "manifold":
-                manifold_pipes = []
+        if hasattr(function_module, "pipes"):
+            manifold_pipes = []

             # Check if pipes is a function or a list
```
```diff
@@ -163,9 +161,7 @@ async def get_pipe_models():
                     manifold_pipe_name = p["name"]

                     if hasattr(function_module, "name"):
-                        manifold_pipe_name = (
-                            f"{function_module.name}{manifold_pipe_name}"
-                        )
+                        manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}"

                     pipe_flag = {"type": pipe.type}
                     if hasattr(function_module, "ChatValves"):
```
```diff
@@ -200,127 +196,59 @@ async def get_pipe_models():
     return pipe_models


-async def generate_function_chat_completion(form_data, user):
-    model_id = form_data.get("model")
-    model_info = Models.get_model_by_id(model_id)
-
-    metadata = None
-    if "metadata" in form_data:
-        metadata = form_data["metadata"]
-        del form_data["metadata"]
-
-    __event_emitter__ = None
-    __event_call__ = None
-    __task__ = None
-
-    if metadata:
-        if (
-            metadata.get("session_id")
-            and metadata.get("chat_id")
-            and metadata.get("message_id")
-        ):
-            __event_emitter__ = await get_event_emitter(metadata)
-            __event_call__ = await get_event_call(metadata)
-
-        if metadata.get("task"):
-            __task__ = metadata.get("task")
-
-    if model_info:
-        if model_info.base_model_id:
-            form_data["model"] = model_info.base_model_id
-
-        model_info.params = model_info.params.model_dump()
-
-        if model_info.params:
-            if model_info.params.get("temperature", None) is not None:
-                form_data["temperature"] = float(model_info.params.get("temperature"))
-
-            if model_info.params.get("top_p", None):
-                form_data["top_p"] = int(model_info.params.get("top_p", None))
-
-            if model_info.params.get("max_tokens", None):
-                form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
-
-            if model_info.params.get("frequency_penalty", None):
-                form_data["frequency_penalty"] = int(
-                    model_info.params.get("frequency_penalty", None)
-                )
-
-            if model_info.params.get("seed", None):
-                form_data["seed"] = model_info.params.get("seed", None)
-
-            if model_info.params.get("stop", None):
-                form_data["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
-
-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-                # Check if the payload already has a system message
-                # If not, add a system message to the payload
-                if form_data.get("messages"):
-                    for message in form_data["messages"]:
-                        if message.get("role") == "system":
-                            message["content"] = system + message["content"]
-                            break
-                    else:
-                        form_data["messages"].insert(
-                            0,
-                            {
-                                "role": "system",
-                                "content": system,
-                            },
-                        )
-    else:
-        pass
-
-    async def job():
-        pipe_id = form_data["model"]
-        if "." in pipe_id:
-            pipe_id, sub_pipe_id = pipe_id.split(".", 1)
-        print(pipe_id)
-
-        # Check if function is already loaded
-        if pipe_id not in app.state.FUNCTIONS:
-            function_module, function_type, frontmatter = load_function_module_by_id(pipe_id)
-            app.state.FUNCTIONS[pipe_id] = function_module
-        else:
-            function_module = app.state.FUNCTIONS[pipe_id]
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(pipe_id)
-            function_module.valves = function_module.Valves(**(valves if valves else {}))
-
-        pipe = function_module.pipe
-
-        # Get the signature of the function
-        sig = inspect.signature(pipe)
-        params = {"body": form_data}
-
-        if "__user__" in sig.parameters:
-            __user__ = {
-                "id": user.id,
+async def execute_pipe(pipe, params):
+    if inspect.iscoroutinefunction(pipe):
+        return await pipe(**params)
+    else:
+        return pipe(**params)
+
+
+async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
+    if isinstance(res, str):
+        return res
+    if isinstance(res, Generator):
+        return "".join(map(str, res))
+    if isinstance(res, AsyncGenerator):
+        return "".join([str(stream) async for stream in res])
+
+
+def process_line(form_data: dict, line):
+    if isinstance(line, BaseModel):
+        line = line.model_dump_json()
+        line = f"data: {line}"
+    if isinstance(line, dict):
+        line = f"data: {json.dumps(line)}"
+
+    try:
+        line = line.decode("utf-8")
+    except Exception:
+        pass
+
+    if line.startswith("data:"):
+        return f"{line}\n\n"
+    else:
+        line = openai_chat_chunk_message_template(form_data["model"], line)
+        return f"data: {json.dumps(line)}\n\n"
+
+
+def get_pipe_id(form_data: dict) -> str:
+    pipe_id = form_data["model"]
+    if "." in pipe_id:
+        pipe_id, _ = pipe_id.split(".", 1)
+    print(pipe_id)
+    return pipe_id
+
+
+def get_function_params(function_module, form_data, user, extra_params={}):
+    pipe_id = get_pipe_id(form_data)
+    # Get the signature of the function
+    sig = inspect.signature(function_module.pipe)
+    params = {"body": form_data}
+
+    for key, value in extra_params.items():
+        if key in sig.parameters:
+            params[key] = value
+
+    if "__user__" in sig.parameters:
+        __user__ = {
+            "id": user.id,
```
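The new `execute_pipe` and `get_message_content` helpers concentrate the sync/async branching that was previously repeated at every call site. A self-contained sketch of how the two behave with each input type; `sync_pipe`, `async_pipe`, and `main` are made-up stand-ins for real pipe functions:

```python
import asyncio
import inspect
from typing import AsyncGenerator, Generator


async def execute_pipe(pipe, params: dict):
    # Await coroutine functions, call plain functions directly.
    if inspect.iscoroutinefunction(pipe):
        return await pipe(**params)
    return pipe(**params)


async def get_message_content(res) -> str:
    # Collapse a string, generator, or async generator into one string.
    if isinstance(res, str):
        return res
    if isinstance(res, Generator):
        return "".join(map(str, res))
    if isinstance(res, AsyncGenerator):
        return "".join([str(chunk) async for chunk in res])
    return ""  # simplified fallback for this sketch


def sync_pipe(body: dict) -> Generator:
    yield from ("Hello", " ", "world")


async def async_pipe(body: dict) -> str:
    return "Hello world"


async def main():
    print(await get_message_content(await execute_pipe(sync_pipe, {"body": {}})))
    print(await get_message_content(await execute_pipe(async_pipe, {"body": {}})))


asyncio.run(main())  # prints "Hello world" twice
```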
```diff
@@ -337,25 +265,94 @@ async def generate_function_chat_completion(form_data, user):
             except Exception as e:
                 print(e)

-        params = {**params, "__user__": __user__}
-
-        if "__event_emitter__" in sig.parameters:
-            params = {**params, "__event_emitter__": __event_emitter__}
-
-        if "__event_call__" in sig.parameters:
-            params = {**params, "__event_call__": __event_call__}
-
-        if "__task__" in sig.parameters:
-            params = {**params, "__task__": __task__}
-
-        if form_data["stream"]:
-
-            async def stream_content():
-                try:
-                    if inspect.iscoroutinefunction(pipe):
-                        res = await pipe(**params)
-                    else:
-                        res = pipe(**params)
-
-                    # Directly return if the response is a StreamingResponse
-                    if isinstance(res, StreamingResponse):
+    params["__user__"] = __user__
+    return params
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
+    if not params:
+        return form_data
+
+    mappings = {
+        "temperature": float,
+        "top_p": int,
+        "max_tokens": int,
+        "frequency_penalty": int,
+        "seed": lambda x: x,
+        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+    }
+
+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
+
+    return form_data
+
+
+# inplace function: form_data is modified
+def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
+    system = params.get("system", None)
+    if not system:
+        return form_data
+
+    if user:
+        template_params = {
+            "user_name": user.name,
+            "user_location": user.info.get("location") if user.info else None,
+        }
+    else:
+        template_params = {}
+
+    system = prompt_template(system, **template_params)
+    form_data["messages"] = add_or_update_system_message(
+        system, form_data.get("messages", [])
+    )
+    return form_data
+
+
+async def generate_function_chat_completion(form_data, user):
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)
+    metadata = form_data.pop("metadata", None)
+
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
+
+    if metadata:
+        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
+            __event_emitter__ = get_event_emitter(metadata)
+            __event_call__ = get_event_call(metadata)
+        __task__ = metadata.get("task", None)
+
+    if model_info:
+        if model_info.base_model_id:
+            form_data["model"] = model_info.base_model_id
+
+        params = model_info.params.model_dump()
+        form_data = apply_model_params_to_body(params, form_data)
+        form_data = apply_model_system_prompt_to_body(params, form_data, user)
+
+    pipe_id = get_pipe_id(form_data)
+    function_module = get_function_module(pipe_id)
+
+    pipe = function_module.pipe
+    params = get_function_params(
+        function_module,
+        form_data,
+        user,
+        {
+            "__event_emitter__": __event_emitter__,
+            "__event_call__": __event_call__,
+            "__task__": __task__,
+        },
+    )
+
+    if form_data["stream"]:
+
+        async def stream_content():
+            try:
+                res = await execute_pipe(pipe, params)
+
+                # Directly return if the response is a StreamingResponse
+                if isinstance(res, StreamingResponse):
```
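The `mappings` table plus the walrus operator replaces six near-identical `if model_info.params.get(...)` blocks. A quick usage sketch of the new helper in isolation, with made-up values:

```python
params = {"temperature": "0.7", "max_tokens": 512, "seed": None, "stop": ["\\n\\n"]}
form_data = {"model": "my-pipe", "messages": []}

apply_model_params_to_body(params, form_data)

# form_data is mutated in place:
#   {'model': 'my-pipe', 'messages': [], 'temperature': 0.7,
#    'max_tokens': 512, 'stop': ['\n\n']}
# "seed" is skipped because params.get("seed") is None, and the literal
# backslash sequence "\\n\\n" is unescaped into a real double newline by
# the bytes(...).decode("unicode_escape") cast.
```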
```diff
@@ -372,112 +369,38 @@ async def generate_function_chat_completion(form_data, user):
                     return

-                    if isinstance(res, str):
-                        message = stream_message_template(form_data["model"], res)
-                        yield f"data: {json.dumps(message)}\n\n"
-
-                    if isinstance(res, Iterator):
-                        for line in res:
-                            if isinstance(line, BaseModel):
-                                line = line.model_dump_json()
-                                line = f"data: {line}"
-                            if isinstance(line, dict):
-                                line = f"data: {json.dumps(line)}"
-
-                            try:
-                                line = line.decode("utf-8")
-                            except:
-                                pass
-
-                            if line.startswith("data:"):
-                                yield f"{line}\n\n"
-                            else:
-                                line = stream_message_template(form_data["model"], line)
-                                yield f"data: {json.dumps(line)}\n\n"
-
-                    if isinstance(res, str) or isinstance(res, Generator):
-                        finish_message = {
-                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                            "object": "chat.completion.chunk",
-                            "created": int(time.time()),
-                            "model": form_data["model"],
-                            "choices": [
-                                {
-                                    "index": 0,
-                                    "delta": {},
-                                    "logprobs": None,
-                                    "finish_reason": "stop",
-                                }
-                            ],
-                        }
-
-                        yield f"data: {json.dumps(finish_message)}\n\n"
-                        yield f"data: [DONE]"
-
-                    if isinstance(res, AsyncGenerator):
-                        async for line in res:
-                            if isinstance(line, BaseModel):
-                                line = line.model_dump_json()
-                                line = f"data: {line}"
-                            if isinstance(line, dict):
-                                line = f"data: {json.dumps(line)}"
-
-                            try:
-                                line = line.decode("utf-8")
-                            except:
-                                pass
-
-                            if line.startswith("data:"):
-                                yield f"{line}\n\n"
-                            else:
-                                line = stream_message_template(form_data["model"], line)
-                                yield f"data: {json.dumps(line)}\n\n"
-
-            return StreamingResponse(stream_content(), media_type="text/event-stream")
-        else:
-            try:
-                if inspect.iscoroutinefunction(pipe):
-                    res = await pipe(**params)
-                else:
-                    res = pipe(**params)
-
-                if isinstance(res, StreamingResponse):
-                    return res
-            except Exception as e:
-                print(f"Error: {e}")
-                return {"error": {"detail": str(e)}}
-
-            if isinstance(res, dict):
-                return res
-            elif isinstance(res, BaseModel):
-                return res.model_dump()
-            else:
-                message = ""
-                if isinstance(res, str):
-                    message = res
-                elif isinstance(res, Generator):
-                    for stream in res:
-                        message = f"{message}{stream}"
-                elif isinstance(res, AsyncGenerator):
-                    async for stream in res:
-                        message = f"{message}{stream}"
-
-                return {
-                    "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": form_data["model"],
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": message,
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-
-    return await job()
+                if isinstance(res, str):
+                    message = openai_chat_chunk_message_template(form_data["model"], res)
+                    yield f"data: {json.dumps(message)}\n\n"
+
+                if isinstance(res, Iterator):
+                    for line in res:
+                        yield process_line(form_data, line)
+
+                if isinstance(res, AsyncGenerator):
+                    async for line in res:
+                        yield process_line(form_data, line)
+
+                if isinstance(res, str) or isinstance(res, Generator):
+                    finish_message = openai_chat_chunk_message_template(
+                        form_data["model"], ""
+                    )
+                    finish_message["choices"][0]["finish_reason"] = "stop"
+                    yield f"data: {json.dumps(finish_message)}\n\n"
+                    yield "data: [DONE]"
+
+        return StreamingResponse(stream_content(), media_type="text/event-stream")
+    else:
+        try:
+            res = await execute_pipe(pipe, params)
+        except Exception as e:
+            print(f"Error: {e}")
+            return {"error": {"detail": str(e)}}
+
+        if isinstance(res, StreamingResponse) or isinstance(res, dict):
+            return res
+        if isinstance(res, BaseModel):
+            return res.model_dump()
+
+        message = await get_message_content(res)
+        return openai_chat_completion_message_template(form_data["model"], message)
```
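For reference, `process_line` (introduced in the hunk above) normalizes every chunk a pipe yields into one server-sent-event frame. A small demonstration of the framing behavior, with made-up chunks:

```python
form_data = {"model": "my-pipe"}

# A dict is JSON-encoded and framed:
process_line(form_data, {"choices": []})
# -> 'data: {"choices": []}\n\n'

# A string that is already an SSE line passes through with the frame terminator added:
process_line(form_data, "data: {...}")
# -> 'data: {...}\n\n'

# A bare string is wrapped in an OpenAI-style chunk first:
process_line(form_data, "Hello")
# -> 'data: {"id": "my-pipe-<uuid>", ..., "object": "chat.completion.chunk",
#     "choices": [{"index": 0, ..., "delta": {"content": "Hello"}}]}\n\n'
```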
backend/apps/webui/models/models.py

```diff
-import json
 import logging
-from typing import Optional
+from typing import Optional, List

 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import String, Column, BigInteger, Text
+from sqlalchemy import Column, BigInteger, Text

 from apps.webui.internal.db import Base, JSONField, get_db
-from typing import List, Union, Optional

 from config import SRC_LOG_LEVELS

 import time
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
 class ModelsTable:
     def insert_new_model(
         self, form_data: ModelForm, user_id: str
     ) -> Optional[ModelModel]:
@@ -126,9 +123,7 @@ class ModelsTable:
             }
         )
         try:
             with get_db() as db:
                 result = Model(**model.model_dump())
                 db.add(result)
                 db.commit()
@@ -144,13 +139,11 @@ class ModelsTable:
     def get_all_models(self) -> List[ModelModel]:
         with get_db() as db:
             return [ModelModel.model_validate(model) for model in db.query(Model).all()]

     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
             with get_db() as db:
                 model = db.get(Model, id)
                 return ModelModel.model_validate(model)
         except:
@@ -178,7 +171,6 @@ class ModelsTable:
     def delete_model_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:
                 db.query(Model).filter_by(id=id).delete()
                 db.commit()
```
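Note: `ModelModel.model_validate(model)` in the table methods above is the Pydantic v2 idiom that only accepts an ORM row when the model's config enables attribute access. A minimal self-contained sketch of the pattern, with simplified fields and a hypothetical `Row` class standing in for a SQLAlchemy result:

```python
from pydantic import BaseModel, ConfigDict


class ModelModel(BaseModel):
    # from_attributes lets model_validate read attributes off plain objects
    # and ORM rows, not just dicts (the v2 replacement for v1's orm_mode).
    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str


class Row:  # stand-in for a SQLAlchemy row
    id = "gpt-x"
    name = "GPT X"


print(ModelModel.model_validate(Row()))  # id='gpt-x' name='GPT X'
```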
backend/main.py

```diff
@@ -13,8 +13,6 @@ import aiohttp
 import requests
 import mimetypes
 import shutil
-import os
-import uuid
 import inspect

 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import StreamingResponse, Response, RedirectResponse

-from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call
+from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
 from apps.ollama.main import (
     app as ollama_app,
     get_all_models as get_ollama_models,
@@ -619,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 content={"detail": str(e)},
             )

-        # Extract valves from the request body
-        valves = None
-        if "valves" in body:
-            valves = body["valves"]
-            del body["valves"]
-
-        # Extract session_id, chat_id and message_id from the request body
-        session_id = None
-        if "session_id" in body:
-            session_id = body["session_id"]
-            del body["session_id"]
-        chat_id = None
-        if "chat_id" in body:
-            chat_id = body["chat_id"]
-            del body["chat_id"]
-        message_id = None
-        if "id" in body:
-            message_id = body["id"]
-            del body["id"]
-
-        __event_emitter__ = await get_event_emitter(
-            {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
-        )
-        __event_call__ = await get_event_call(
-            {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
-        )
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "valves": body.pop("valves", None),
+        }
+
+        __event_emitter__ = get_event_emitter(metadata)
+        __event_call__ = get_event_call(metadata)

         # Initialize data_items to store additional data to be sent to the client
         data_items = []
@@ -709,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         if len(citations) > 0:
             data_items.append({"citations": citations})

-        body["metadata"] = {
-            "session_id": session_id,
-            "chat_id": chat_id,
-            "message_id": message_id,
-            "valves": valves,
-        }
+        body["metadata"] = metadata

         modified_body_bytes = json.dumps(body).encode("utf-8")
         # Replace the request body with the modified one
         request._body = modified_body_bytes
@@ -1191,13 +1166,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                 status_code=r.status_code,
                 content=res,
             )
-    except:
+    except Exception:
         pass

     else:
         pass

-    __event_emitter__ = await get_event_emitter(
+    __event_emitter__ = get_event_emitter(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1205,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         }
     )
-    __event_call__ = await get_event_call(
+    __event_call__ = get_event_call(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1334,14 +1309,14 @@ async def chat_completed(
     )
     model = app.state.MODELS[model_id]

-    __event_emitter__ = await get_event_emitter(
+    __event_emitter__ = get_event_emitter(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
             "session_id": data["session_id"],
         }
     )
-    __event_call__ = await get_event_call(
+    __event_call__ = get_event_call(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
```
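The middleware change above collapses four mutate-and-delete blocks into one dict literal built with `dict.pop`, which reads a key and removes it in a single step, returning the default when the key is absent. A tiny standalone illustration with a fabricated request body:

```python
body = {"model": "llama3", "chat_id": "c1", "id": "m1", "messages": []}

metadata = {
    "chat_id": body.pop("chat_id", None),
    "message_id": body.pop("id", None),
    "session_id": body.pop("session_id", None),  # absent -> None
    "valves": body.pop("valves", None),          # absent -> None
}

print(metadata)  # {'chat_id': 'c1', 'message_id': 'm1', 'session_id': None, 'valves': None}
print(body)      # {'model': 'llama3', 'messages': []} - the keys were consumed
```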
backend/utils/misc.py

```diff
 from pathlib import Path
 import hashlib
-import json
 import re
 from datetime import timedelta
 from typing import Optional, List, Tuple
@@ -8,37 +7,39 @@ import uuid
 import time


-def get_last_user_message_item(messages: List[dict]) -> str:
+def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
     for message in reversed(messages):
         if message["role"] == "user":
             return message
     return None


-def get_last_user_message(messages: List[dict]) -> str:
-    message = get_last_user_message_item(messages)
-
-    if message is not None:
-        if isinstance(message["content"], list):
-            for item in message["content"]:
-                if item["type"] == "text":
-                    return item["text"]
-        return message["content"]
-    return None
-
-
-def get_last_assistant_message(messages: List[dict]) -> str:
-    for message in reversed(messages):
-        if message["role"] == "assistant":
-            if isinstance(message["content"], list):
-                for item in message["content"]:
-                    if item["type"] == "text":
-                        return item["text"]
-            return message["content"]
-    return None
-
-
-def get_system_message(messages: List[dict]) -> dict:
+def get_content_from_message(message: dict) -> Optional[str]:
+    if isinstance(message["content"], list):
+        for item in message["content"]:
+            if item["type"] == "text":
+                return item["text"]
+    else:
+        return message["content"]
+    return None
+
+
+def get_last_user_message(messages: List[dict]) -> Optional[str]:
+    message = get_last_user_message_item(messages)
+    if message is None:
+        return None
+    return get_content_from_message(message)
+
+
+def get_last_assistant_message(messages: List[dict]) -> Optional[str]:
+    for message in reversed(messages):
+        if message["role"] == "assistant":
+            return get_content_from_message(message)
+    return None
+
+
+def get_system_message(messages: List[dict]) -> Optional[dict]:
     for message in messages:
         if message["role"] == "system":
             return message
@@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]:
     return [message for message in messages if message["role"] != "system"]


-def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
+def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]:
     return get_system_message(messages), remove_system_message(messages)
@@ -87,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages


-def stream_message_template(model: str, message: str):
+def openai_chat_message_template(model: str):
     return {
         "id": f"{model}-{str(uuid.uuid4())}",
-        "object": "chat.completion.chunk",
         "created": int(time.time()),
         "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"content": message},
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
+        "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
     }


+def openai_chat_chunk_message_template(model: str, message: str):
+    template = openai_chat_message_template(model)
+    template["object"] = "chat.completion.chunk"
+    template["choices"][0]["delta"] = {"content": message}
+    return template
+
+
+def openai_chat_completion_message_template(model: str, message: str):
+    template = openai_chat_message_template(model)
+    template["object"] = "chat.completion"
+    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+    template["choices"][0]["finish_reason"] = "stop"
+    return template
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
@@ -174,7 +181,7 @@ def extract_folders_after_data_docs(path):
     tags = []

     folders = parts[index_docs:-1]
-    for idx, part in enumerate(folders):
+    for idx, _ in enumerate(folders):
         tags.append("/".join(folders[: idx + 1]))

     return tags
@@ -270,11 +277,11 @@ def parse_ollama_modelfile(model_text):
                     value = param_match.group(1)

                     try:
-                        if param_type == int:
+                        if param_type is int:
                             value = int(value)
-                        elif param_type == float:
+                        elif param_type is float:
                             value = float(value)
-                        elif param_type == bool:
+                        elif param_type is bool:
                             value = value.lower() == "true"
                     except Exception as e:
                         print(e)
```

(The closing `return template` of `openai_chat_completion_message_template` was dropped by the page extraction; it is restored above, since the function's result is consumed by `generate_function_chat_completion`.)
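The template split means the streaming chunk and the final completion now share one base payload, differing only in `object` and the per-choice fields. A quick check of what the two builders produce (ids and timestamps vary per call):

```python
chunk = openai_chat_chunk_message_template("my-model", "Hi")
# {'id': 'my-model-<uuid>', 'created': <ts>, 'model': 'my-model',
#  'object': 'chat.completion.chunk',
#  'choices': [{'index': 0, 'logprobs': None, 'finish_reason': None,
#               'delta': {'content': 'Hi'}}]}

completion = openai_chat_completion_message_template("my-model", "Hi")
# {'id': 'my-model-<uuid>', 'created': <ts>, 'model': 'my-model',
#  'object': 'chat.completion',
#  'choices': [{'index': 0, 'logprobs': None, 'finish_reason': 'stop',
#               'message': {'content': 'Hi', 'role': 'assistant'}}]}
```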