Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
3978efd7
Commit
3978efd7
authored
Jul 31, 2024
by
Michael Poluektov
Browse files
refac: Refactor functions
parent
9d58bb1c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
219 additions
and
296 deletions
+219
-296
backend/apps/socket/main.py
backend/apps/socket/main.py
+2
-4
backend/apps/webui/main.py
backend/apps/webui/main.py
+207
-269
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+2
-10
backend/main.py
backend/main.py
+8
-13
No files found.
backend/apps/socket/main.py
View file @
3978efd7
...
@@ -52,7 +52,6 @@ async def user_join(sid, data):
...
@@ -52,7 +52,6 @@ async def user_join(sid, data):
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
if
user
:
if
user
:
SESSION_POOL
[
sid
]
=
user
.
id
SESSION_POOL
[
sid
]
=
user
.
id
if
user
.
id
in
USER_POOL
:
if
user
.
id
in
USER_POOL
:
USER_POOL
[
user
.
id
].
append
(
sid
)
USER_POOL
[
user
.
id
].
append
(
sid
)
...
@@ -80,7 +79,6 @@ def get_models_in_use():
...
@@ -80,7 +79,6 @@ def get_models_in_use():
@
sio
.
on
(
"usage"
)
@
sio
.
on
(
"usage"
)
async
def
usage
(
sid
,
data
):
async
def
usage
(
sid
,
data
):
model_id
=
data
[
"model"
]
model_id
=
data
[
"model"
]
# Cancel previous callback if there is one
# Cancel previous callback if there is one
...
@@ -139,7 +137,7 @@ async def disconnect(sid):
...
@@ -139,7 +137,7 @@ async def disconnect(sid):
print
(
f
"Unknown session ID
{
sid
}
disconnected"
)
print
(
f
"Unknown session ID
{
sid
}
disconnected"
)
async
def
get_event_emitter
(
request_info
):
def
get_event_emitter
(
request_info
):
async
def
__event_emitter__
(
event_data
):
async
def
__event_emitter__
(
event_data
):
await
sio
.
emit
(
await
sio
.
emit
(
"chat-events"
,
"chat-events"
,
...
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
...
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
return
__event_emitter__
return
__event_emitter__
async
def
get_event_call
(
request_info
):
def
get_event_call
(
request_info
):
async
def
__event_call__
(
event_data
):
async
def
__event_call__
(
event_data
):
response
=
await
sio
.
call
(
response
=
await
sio
.
call
(
"chat-events"
,
"chat-events"
,
...
...
backend/apps/webui/main.py
View file @
3978efd7
from
fastapi
import
FastAPI
,
Depends
from
fastapi
import
FastAPI
from
fastapi.routing
import
APIRoute
from
fastapi.responses
import
StreamingResponse
from
fastapi.responses
import
StreamingResponse
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
from
sqlalchemy.orm
import
Session
from
apps.webui.routers
import
(
from
apps.webui.routers
import
(
auths
,
auths
,
users
,
users
,
...
@@ -27,7 +24,6 @@ from utils.task import prompt_template
...
@@ -27,7 +24,6 @@ from utils.task import prompt_template
from
config
import
(
from
config
import
(
WEBUI_BUILD_HASH
,
SHOW_ADMIN_DETAILS
,
SHOW_ADMIN_DETAILS
,
ADMIN_EMAIL
,
ADMIN_EMAIL
,
WEBUI_AUTH
,
WEBUI_AUTH
,
...
@@ -55,7 +51,7 @@ import uuid
...
@@ -55,7 +51,7 @@ import uuid
import
time
import
time
import
json
import
json
from
typing
import
Iterator
,
Generator
,
AsyncGenerator
,
Optional
from
typing
import
Iterator
,
Generator
,
AsyncGenerator
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
app
=
FastAPI
()
app
=
FastAPI
()
...
@@ -127,29 +123,31 @@ async def get_status():
...
@@ -127,29 +123,31 @@ async def get_status():
}
}
def
get_function_module
(
pipe_id
:
str
):
# Check if function is already loaded
if
pipe_id
not
in
app
.
state
.
FUNCTIONS
:
function_module
,
_
,
_
=
load_function_module_by_id
(
pipe_id
)
app
.
state
.
FUNCTIONS
[
pipe_id
]
=
function_module
else
:
function_module
=
app
.
state
.
FUNCTIONS
[
pipe_id
]
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
pipe_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{}))
return
function_module
async
def
get_pipe_models
():
async
def
get_pipe_models
():
pipes
=
Functions
.
get_functions_by_type
(
"pipe"
,
active_only
=
True
)
pipes
=
Functions
.
get_functions_by_type
(
"pipe"
,
active_only
=
True
)
pipe_models
=
[]
pipe_models
=
[]
for
pipe
in
pipes
:
for
pipe
in
pipes
:
# Check if function is already loaded
function_module
=
get_function_module
(
pipe
.
id
)
if
pipe
.
id
not
in
app
.
state
.
FUNCTIONS
:
function_module
,
function_type
,
frontmatter
=
load_function_module_by_id
(
pipe
.
id
)
app
.
state
.
FUNCTIONS
[
pipe
.
id
]
=
function_module
else
:
function_module
=
app
.
state
.
FUNCTIONS
[
pipe
.
id
]
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
pipe
.
id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
# Check if function is a manifold
# Check if function is a manifold
if
hasattr
(
function_module
,
"type"
):
if
hasattr
(
function_module
,
"type"
):
if
function_module
.
type
==
"manifold"
:
if
not
function_module
.
type
==
"manifold"
:
continue
manifold_pipes
=
[]
manifold_pipes
=
[]
# Check if pipes is a function or a list
# Check if pipes is a function or a list
...
@@ -163,9 +161,7 @@ async def get_pipe_models():
...
@@ -163,9 +161,7 @@ async def get_pipe_models():
manifold_pipe_name
=
p
[
"name"
]
manifold_pipe_name
=
p
[
"name"
]
if
hasattr
(
function_module
,
"name"
):
if
hasattr
(
function_module
,
"name"
):
manifold_pipe_name
=
(
manifold_pipe_name
=
f
"
{
function_module
.
name
}{
manifold_pipe_name
}
"
f
"
{
function_module
.
name
}{
manifold_pipe_name
}
"
)
pipe_flag
=
{
"type"
:
pipe
.
type
}
pipe_flag
=
{
"type"
:
pipe
.
type
}
if
hasattr
(
function_module
,
"ChatValves"
):
if
hasattr
(
function_module
,
"ChatValves"
):
...
@@ -200,127 +196,81 @@ async def get_pipe_models():
...
@@ -200,127 +196,81 @@ async def get_pipe_models():
return
pipe_models
return
pipe_models
async
def
generate_function_chat_completion
(
form_data
,
user
):
async
def
execute_pipe
(
pipe
,
params
):
model_id
=
form_data
.
get
(
"model"
)
if
inspect
.
iscoroutinefunction
(
pipe
):
model_info
=
Models
.
get_model_by_id
(
model_id
)
return
await
pipe
(
**
params
)
else
:
return
pipe
(
**
params
)
metadata
=
None
if
"metadata"
in
form_data
:
metadata
=
form_data
[
"metadata"
]
del
form_data
[
"metadata"
]
__event_emitter__
=
None
async
def
get_message
(
res
:
str
|
Generator
|
AsyncGenerator
)
->
str
:
__event_call__
=
None
if
isinstance
(
res
,
str
):
__task__
=
None
return
res
if
isinstance
(
res
,
Generator
):
return
""
.
join
(
map
(
str
,
res
))
if
isinstance
(
res
,
AsyncGenerator
):
return
""
.
join
([
str
(
stream
)
async
for
stream
in
res
])
if
metadata
:
if
(
metadata
.
get
(
"session_id"
)
and
metadata
.
get
(
"chat_id"
)
and
metadata
.
get
(
"message_id"
)
):
__event_emitter__
=
await
get_event_emitter
(
metadata
)
__event_call__
=
await
get_event_call
(
metadata
)
if
metadata
.
get
(
"task"
):
__task__
=
metadata
.
get
(
"task"
)
if
model_info
:
if
model_info
.
base_model_id
:
form_data
[
"model"
]
=
model_info
.
base_model_id
model_info
.
params
=
model_info
.
params
.
model_dump
()
def
get_final_message
(
form_data
:
dict
,
message
:
str
|
None
=
None
)
->
dict
:
choice
=
{
"index"
:
0
,
"logprobs"
:
None
,
"finish_reason"
:
"stop"
,
}
if
model_info
.
params
:
# If message is None, we're dealing with a chunk
if
model_info
.
params
.
get
(
"temperature"
,
None
)
is
not
None
:
if
not
message
:
form_data
[
"temperature"
]
=
float
(
model_info
.
params
.
get
(
"temperature"
))
choice
[
"delta"
]
=
{}
else
:
choice
[
"message"
]
=
{
"role"
:
"assistant"
,
"content"
:
message
}
if
model_info
.
params
.
get
(
"top_p"
,
None
):
return
{
form_data
[
"top_p"
]
=
int
(
model_info
.
params
.
get
(
"top_p"
,
None
))
"id"
:
f
"
{
form_data
[
'model'
]
}
-
{
str
(
uuid
.
uuid4
())
}
"
,
"created"
:
int
(
time
.
time
()),
"model"
:
form_data
[
"model"
],
"object"
:
"chat.completion"
if
message
is
not
None
else
"chat.completion.chunk"
,
"choices"
:
[
choice
],
}
if
model_info
.
params
.
get
(
"max_tokens"
,
None
):
form_data
[
"max_tokens"
]
=
int
(
model_info
.
params
.
get
(
"max_tokens"
,
None
))
if
model_info
.
params
.
get
(
"frequency_penalty"
,
None
):
def
process_line
(
form_data
:
dict
,
line
):
form_data
[
"frequency_penalty"
]
=
int
(
if
isinstance
(
line
,
BaseModel
):
model_info
.
params
.
get
(
"frequency_penalty"
,
None
)
line
=
line
.
model_dump_json
()
)
line
=
f
"data:
{
line
}
"
if
isinstance
(
line
,
dict
):
line
=
f
"data:
{
json
.
dumps
(
line
)
}
"
if
model_info
.
params
.
get
(
"seed"
,
None
):
try
:
form_data
[
"seed"
]
=
model_info
.
params
.
get
(
"seed"
,
None
)
line
=
line
.
decode
(
"utf-8"
)
except
Exception
:
if
model_info
.
params
.
get
(
"stop"
,
None
):
pass
form_data
[
"stop"
]
=
(
[
bytes
(
stop
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
stop
in
model_info
.
params
[
"stop"
]
]
if
model_info
.
params
.
get
(
"stop"
,
None
)
else
None
)
system
=
model_info
.
params
.
get
(
"system"
,
None
)
if
line
.
startswith
(
"data:"
):
if
system
:
return
f
"
{
line
}
\n\n
"
system
=
prompt_template
(
system
,
**
(
{
"user_name"
:
user
.
name
,
"user_location"
:
(
user
.
info
.
get
(
"location"
)
if
user
.
info
else
None
),
}
if
user
else
{}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if
form_data
.
get
(
"messages"
):
for
message
in
form_data
[
"messages"
]:
if
message
.
get
(
"role"
)
==
"system"
:
message
[
"content"
]
=
system
+
message
[
"content"
]
break
else
:
else
:
form_data
[
"messages"
].
insert
(
line
=
stream_message_template
(
form_data
[
"model"
],
line
)
0
,
return
f
"data:
{
json
.
dumps
(
line
)
}
\n\n
"
{
"role"
:
"system"
,
"content"
:
system
,
},
)
else
:
pass
async
def
job
()
:
def
get_pipe_id
(
form_data
:
dict
)
->
str
:
pipe_id
=
form_data
[
"model"
]
pipe_id
=
form_data
[
"model"
]
if
"."
in
pipe_id
:
if
"."
in
pipe_id
:
pipe_id
,
sub_pipe_id
=
pipe_id
.
split
(
"."
,
1
)
pipe_id
,
_
=
pipe_id
.
split
(
"."
,
1
)
print
(
pipe_id
)
print
(
pipe_id
)
return
pipe_id
# Check if function is already loaded
if
pipe_id
not
in
app
.
state
.
FUNCTIONS
:
function_module
,
function_type
,
frontmatter
=
load_function_module_by_id
(
pipe_id
)
app
.
state
.
FUNCTIONS
[
pipe_id
]
=
function_module
else
:
function_module
=
app
.
state
.
FUNCTIONS
[
pipe_id
]
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
pipe_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
pipe
=
function_module
.
pipe
def
get_params_dict
(
pipe
,
form_data
,
user
,
extra_params
,
function_module
):
pipe_id
=
get_pipe_id
(
form_data
)
# Get the signature of the function
# Get the signature of the function
sig
=
inspect
.
signature
(
pipe
)
sig
=
inspect
.
signature
(
pipe
)
params
=
{
"body"
:
form_data
}
params
=
{
"body"
:
form_data
}
for
key
,
value
in
extra_params
.
items
():
if
key
in
sig
.
parameters
:
params
[
key
]
=
value
if
"__user__"
in
sig
.
parameters
:
if
"__user__"
in
sig
.
parameters
:
__user__
=
{
__user__
=
{
"id"
:
user
.
id
,
"id"
:
user
.
id
,
...
@@ -337,25 +287,88 @@ async def generate_function_chat_completion(form_data, user):
...
@@ -337,25 +287,88 @@ async def generate_function_chat_completion(form_data, user):
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
print
(
e
)
params
=
{
**
params
,
"__user__"
:
__user__
}
params
[
"__user__"
]
=
__user__
return
params
if
"__event_emitter__"
in
sig
.
parameters
:
async
def
generate_function_chat_completion
(
form_data
,
user
):
params
=
{
**
params
,
"__event_emitter__"
:
__event_emitter__
}
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
if
"__event_call__"
in
sig
.
parameters
:
metadata
=
form_data
.
pop
(
"metadata"
,
None
)
params
=
{
**
params
,
"__event_call__"
:
__event_call__
}
if
"__task__"
in
sig
.
parameters
:
__event_emitter__
=
__event_call__
=
__task__
=
None
params
=
{
**
params
,
"__task__"
:
__task__
}
if
metadata
:
if
all
(
k
in
metadata
for
k
in
(
"session_id"
,
"chat_id"
,
"message_id"
)):
__event_emitter__
=
get_event_emitter
(
metadata
)
__event_call__
=
get_event_call
(
metadata
)
__task__
=
metadata
.
get
(
"task"
,
None
)
if
not
model_info
:
return
if
model_info
.
base_model_id
:
form_data
[
"model"
]
=
model_info
.
base_model_id
params
=
model_info
.
params
.
model_dump
()
if
params
:
mappings
=
{
"temperature"
:
float
,
"top_p"
:
int
,
"max_tokens"
:
int
,
"frequency_penalty"
:
int
,
"seed"
:
lambda
x
:
x
,
"stop"
:
lambda
x
:
[
bytes
(
s
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
s
in
x
],
}
for
key
,
cast_func
in
mappings
.
items
():
if
(
value
:
=
params
.
get
(
key
))
is
not
None
:
form_data
[
key
]
=
cast_func
(
value
)
system
=
params
.
get
(
"system"
,
None
)
if
not
system
:
return
if
user
:
template_params
=
{
"user_name"
:
user
.
name
,
"user_location"
:
user
.
info
.
get
(
"location"
)
if
user
.
info
else
None
,
}
else
:
template_params
=
{}
system
=
prompt_template
(
system
,
**
template_params
)
# Check if the payload already has a system message
# If not, add a system message to the payload
for
message
in
form_data
.
get
(
"messages"
,
[]):
if
message
.
get
(
"role"
)
==
"system"
:
message
[
"content"
]
=
system
+
message
[
"content"
]
break
else
:
if
form_data
.
get
(
"messages"
):
form_data
[
"messages"
].
insert
(
0
,
{
"role"
:
"system"
,
"content"
:
system
})
extra_params
=
{
"__event_emitter__"
:
__event_emitter__
,
"__event_call__"
:
__event_call__
,
"__task__"
:
__task__
,
}
async
def
job
():
pipe_id
=
get_pipe_id
(
form_data
)
function_module
=
get_function_module
(
pipe_id
)
pipe
=
function_module
.
pipe
params
=
get_params_dict
(
pipe
,
form_data
,
user
,
extra_params
,
function_module
)
if
form_data
[
"stream"
]:
if
form_data
[
"stream"
]:
async
def
stream_content
():
async
def
stream_content
():
try
:
try
:
if
inspect
.
iscoroutinefunction
(
pipe
):
res
=
await
execute_pipe
(
pipe
,
params
)
res
=
await
pipe
(
**
params
)
else
:
res
=
pipe
(
**
params
)
# Directly return if the response is a StreamingResponse
# Directly return if the response is a StreamingResponse
if
isinstance
(
res
,
StreamingResponse
):
if
isinstance
(
res
,
StreamingResponse
):
...
@@ -377,107 +390,32 @@ async def generate_function_chat_completion(form_data, user):
...
@@ -377,107 +390,32 @@ async def generate_function_chat_completion(form_data, user):
if
isinstance
(
res
,
Iterator
):
if
isinstance
(
res
,
Iterator
):
for
line
in
res
:
for
line
in
res
:
if
isinstance
(
line
,
BaseModel
):
yield
process_line
(
form_data
,
line
)
line
=
line
.
model_dump_json
()
line
=
f
"data:
{
line
}
"
if
isinstance
(
line
,
dict
):
line
=
f
"data:
{
json
.
dumps
(
line
)
}
"
try
:
line
=
line
.
decode
(
"utf-8"
)
except
:
pass
if
line
.
startswith
(
"data:"
):
yield
f
"
{
line
}
\n\n
"
else
:
line
=
stream_message_template
(
form_data
[
"model"
],
line
)
yield
f
"data:
{
json
.
dumps
(
line
)
}
\n\n
"
if
isinstance
(
res
,
str
)
or
isinstance
(
res
,
Generator
):
finish_message
=
{
"id"
:
f
"
{
form_data
[
'model'
]
}
-
{
str
(
uuid
.
uuid4
())
}
"
,
"object"
:
"chat.completion.chunk"
,
"created"
:
int
(
time
.
time
()),
"model"
:
form_data
[
"model"
],
"choices"
:
[
{
"index"
:
0
,
"delta"
:
{},
"logprobs"
:
None
,
"finish_reason"
:
"stop"
,
}
],
}
yield
f
"data:
{
json
.
dumps
(
finish_message
)
}
\n\n
"
yield
f
"data: [DONE]"
if
isinstance
(
res
,
AsyncGenerator
):
if
isinstance
(
res
,
AsyncGenerator
):
async
for
line
in
res
:
async
for
line
in
res
:
if
isinstance
(
line
,
BaseModel
):
yield
process_line
(
form_data
,
line
)
line
=
line
.
model_dump_json
()
line
=
f
"data:
{
line
}
"
if
isinstance
(
line
,
dict
):
line
=
f
"data:
{
json
.
dumps
(
line
)
}
"
try
:
if
isinstance
(
res
,
str
)
or
isinstance
(
res
,
Generator
):
line
=
line
.
decode
(
"utf-8"
)
finish_message
=
get_final_message
(
form_data
)
except
:
yield
f
"data:
{
json
.
dumps
(
finish_message
)
}
\n\n
"
pass
yield
"data: [DONE]"
if
line
.
startswith
(
"data:"
):
yield
f
"
{
line
}
\n\n
"
else
:
line
=
stream_message_template
(
form_data
[
"model"
],
line
)
yield
f
"data:
{
json
.
dumps
(
line
)
}
\n\n
"
return
StreamingResponse
(
stream_content
(),
media_type
=
"text/event-stream"
)
return
StreamingResponse
(
stream_content
(),
media_type
=
"text/event-stream"
)
else
:
else
:
try
:
try
:
if
inspect
.
iscoroutinefunction
(
pipe
):
res
=
await
execute_pipe
(
pipe
,
params
)
res
=
await
pipe
(
**
params
)
else
:
res
=
pipe
(
**
params
)
if
isinstance
(
res
,
StreamingResponse
):
return
res
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
print
(
f
"Error:
{
e
}
"
)
return
{
"error"
:
{
"detail"
:
str
(
e
)}}
return
{
"error"
:
{
"detail"
:
str
(
e
)}}
if
isinstance
(
res
,
dict
):
if
isinstance
(
res
,
StreamingResponse
)
or
isinstance
(
res
,
dict
):
return
res
return
res
el
if
isinstance
(
res
,
BaseModel
):
if
isinstance
(
res
,
BaseModel
):
return
res
.
model_dump
()
return
res
.
model_dump
()
else
:
message
=
""
if
isinstance
(
res
,
str
):
message
=
res
elif
isinstance
(
res
,
Generator
):
for
stream
in
res
:
message
=
f
"
{
message
}{
stream
}
"
elif
isinstance
(
res
,
AsyncGenerator
):
async
for
stream
in
res
:
message
=
f
"
{
message
}{
stream
}
"
return
{
message
=
await
get_message
(
res
)
"id"
:
f
"
{
form_data
[
'model'
]
}
-
{
str
(
uuid
.
uuid4
())
}
"
,
return
get_final_message
(
form_data
,
message
)
"object"
:
"chat.completion"
,
"created"
:
int
(
time
.
time
()),
"model"
:
form_data
[
"model"
],
"choices"
:
[
{
"index"
:
0
,
"message"
:
{
"role"
:
"assistant"
,
"content"
:
message
,
},
"logprobs"
:
None
,
"finish_reason"
:
"stop"
,
}
],
}
return
await
job
()
return
await
job
()
backend/apps/webui/models/models.py
View file @
3978efd7
import
json
import
logging
import
logging
from
typing
import
Optional
from
typing
import
Optional
,
List
from
pydantic
import
BaseModel
,
ConfigDict
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
sqlalchemy
import
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_db
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_db
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
from
config
import
SRC_LOG_LEVELS
import
time
import
time
...
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
...
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
class
ModelsTable
:
class
ModelsTable
:
def
insert_new_model
(
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
self
,
form_data
:
ModelForm
,
user_id
:
str
)
->
Optional
[
ModelModel
]:
)
->
Optional
[
ModelModel
]:
...
@@ -126,9 +123,7 @@ class ModelsTable:
...
@@ -126,9 +123,7 @@ class ModelsTable:
}
}
)
)
try
:
try
:
with
get_db
()
as
db
:
with
get_db
()
as
db
:
result
=
Model
(
**
model
.
model_dump
())
result
=
Model
(
**
model
.
model_dump
())
db
.
add
(
result
)
db
.
add
(
result
)
db
.
commit
()
db
.
commit
()
...
@@ -144,13 +139,11 @@ class ModelsTable:
...
@@ -144,13 +139,11 @@ class ModelsTable:
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
with
get_db
()
as
db
:
with
get_db
()
as
db
:
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
try
:
with
get_db
()
as
db
:
with
get_db
()
as
db
:
model
=
db
.
get
(
Model
,
id
)
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
.
model_validate
(
model
)
return
ModelModel
.
model_validate
(
model
)
except
:
except
:
...
@@ -178,7 +171,6 @@ class ModelsTable:
...
@@ -178,7 +171,6 @@ class ModelsTable:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
try
:
with
get_db
()
as
db
:
with
get_db
()
as
db
:
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
db
.
commit
()
...
...
backend/main.py
View file @
3978efd7
...
@@ -13,8 +13,6 @@ import aiohttp
...
@@ -13,8 +13,6 @@ import aiohttp
import
requests
import
requests
import
mimetypes
import
mimetypes
import
shutil
import
shutil
import
os
import
uuid
import
inspect
import
inspect
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
,
UploadFile
,
File
,
Form
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
,
UploadFile
,
File
,
Form
...
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
...
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
,
RedirectResponse
from
starlette.responses
import
StreamingResponse
,
Response
,
RedirectResponse
from
apps.socket.main
import
sio
,
app
as
socket_app
,
get_event_emitter
,
get_event_call
from
apps.socket.main
import
app
as
socket_app
,
get_event_emitter
,
get_event_call
from
apps.ollama.main
import
(
from
apps.ollama.main
import
(
app
as
ollama_app
,
app
as
ollama_app
,
get_all_models
as
get_ollama_models
,
get_all_models
as
get_ollama_models
,
...
@@ -639,10 +637,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -639,10 +637,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
message_id
=
body
[
"id"
]
message_id
=
body
[
"id"
]
del
body
[
"id"
]
del
body
[
"id"
]
__event_emitter__
=
await
get_event_emitter
(
__event_emitter__
=
get_event_emitter
(
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
)
)
__event_call__
=
await
get_event_call
(
__event_call__
=
get_event_call
(
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
)
)
...
@@ -1191,13 +1189,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
...
@@ -1191,13 +1189,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
status_code
=
r
.
status_code
,
status_code
=
r
.
status_code
,
content
=
res
,
content
=
res
,
)
)
except
:
except
Exception
:
pass
pass
else
:
else
:
pass
pass
__event_emitter__
=
await
get_event_emitter
(
__event_emitter__
=
get_event_emitter
(
{
{
"chat_id"
:
data
[
"chat_id"
],
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"message_id"
:
data
[
"id"
],
...
@@ -1205,7 +1203,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
...
@@ -1205,7 +1203,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
}
}
)
)
__event_call__
=
await
get_event_call
(
__event_call__
=
get_event_call
(
{
{
"chat_id"
:
data
[
"chat_id"
],
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"message_id"
:
data
[
"id"
],
...
@@ -1334,14 +1332,14 @@ async def chat_completed(
...
@@ -1334,14 +1332,14 @@ async def chat_completed(
)
)
model
=
app
.
state
.
MODELS
[
model_id
]
model
=
app
.
state
.
MODELS
[
model_id
]
__event_emitter__
=
await
get_event_emitter
(
__event_emitter__
=
get_event_emitter
(
{
{
"chat_id"
:
data
[
"chat_id"
],
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"message_id"
:
data
[
"id"
],
"session_id"
:
data
[
"session_id"
],
"session_id"
:
data
[
"session_id"
],
}
}
)
)
__event_call__
=
await
get_event_call
(
__event_call__
=
get_event_call
(
{
{
"chat_id"
:
data
[
"chat_id"
],
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"message_id"
:
data
[
"id"
],
...
@@ -1770,7 +1768,6 @@ class AddPipelineForm(BaseModel):
...
@@ -1770,7 +1768,6 @@ class AddPipelineForm(BaseModel):
@
app
.
post
(
"/api/pipelines/add"
)
@
app
.
post
(
"/api/pipelines/add"
)
async
def
add_pipeline
(
form_data
:
AddPipelineForm
,
user
=
Depends
(
get_admin_user
)):
async
def
add_pipeline
(
form_data
:
AddPipelineForm
,
user
=
Depends
(
get_admin_user
)):
r
=
None
r
=
None
try
:
try
:
urlIdx
=
form_data
.
urlIdx
urlIdx
=
form_data
.
urlIdx
...
@@ -1813,7 +1810,6 @@ class DeletePipelineForm(BaseModel):
...
@@ -1813,7 +1810,6 @@ class DeletePipelineForm(BaseModel):
@
app
.
delete
(
"/api/pipelines/delete"
)
@
app
.
delete
(
"/api/pipelines/delete"
)
async
def
delete_pipeline
(
form_data
:
DeletePipelineForm
,
user
=
Depends
(
get_admin_user
)):
async
def
delete_pipeline
(
form_data
:
DeletePipelineForm
,
user
=
Depends
(
get_admin_user
)):
r
=
None
r
=
None
try
:
try
:
urlIdx
=
form_data
.
urlIdx
urlIdx
=
form_data
.
urlIdx
...
@@ -1891,7 +1887,6 @@ async def get_pipeline_valves(
...
@@ -1891,7 +1887,6 @@ async def get_pipeline_valves(
models
=
await
get_all_models
()
models
=
await
get_all_models
()
r
=
None
r
=
None
try
:
try
:
url
=
openai_app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
urlIdx
]
url
=
openai_app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
urlIdx
]
key
=
openai_app
.
state
.
config
.
OPENAI_API_KEYS
[
urlIdx
]
key
=
openai_app
.
state
.
config
.
OPENAI_API_KEYS
[
urlIdx
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment