Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
a9a6ed8b
Unverified
Commit
a9a6ed8b
authored
Aug 02, 2024
by
Timothy Jaeryang Baek
Committed by
GitHub
Aug 02, 2024
Browse files
Merge pull request #4237 from michaelpoluektov/refactor-webui-main
refactor: Simplify functions
parents
64b41655
e6c64282
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
284 additions
and
389 deletions
+284
-389
backend/apps/socket/main.py
backend/apps/socket/main.py
+2
-4
backend/apps/webui/main.py
backend/apps/webui/main.py
+225
-302
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+2
-10
backend/main.py
backend/main.py
+16
-41
backend/utils/misc.py
backend/utils/misc.py
+39
-32
No files found.
backend/apps/socket/main.py
View file @
a9a6ed8b
...
...
@@ -52,7 +52,6 @@ async def user_join(sid, data):
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
if
user
:
SESSION_POOL
[
sid
]
=
user
.
id
if
user
.
id
in
USER_POOL
:
USER_POOL
[
user
.
id
].
append
(
sid
)
...
...
@@ -80,7 +79,6 @@ def get_models_in_use():
@
sio
.
on
(
"usage"
)
async
def
usage
(
sid
,
data
):
model_id
=
data
[
"model"
]
# Cancel previous callback if there is one
...
...
@@ -139,7 +137,7 @@ async def disconnect(sid):
print
(
f
"Unknown session ID
{
sid
}
disconnected"
)
async
def
get_event_emitter
(
request_info
):
def
get_event_emitter
(
request_info
):
async
def
__event_emitter__
(
event_data
):
await
sio
.
emit
(
"chat-events"
,
...
...
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
return
__event_emitter__
async
def
get_event_call
(
request_info
):
def
get_event_call
(
request_info
):
async
def
__event_call__
(
event_data
):
response
=
await
sio
.
call
(
"chat-events"
,
...
...
backend/apps/webui/main.py
View file @
a9a6ed8b
from
fastapi
import
FastAPI
,
Depends
from
fastapi.routing
import
APIRoute
from
fastapi
import
FastAPI
from
fastapi.responses
import
StreamingResponse
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
from
sqlalchemy.orm
import
Session
from
apps.webui.routers
import
(
auths
,
users
,
...
...
@@ -22,12 +19,15 @@ from apps.webui.models.functions import Functions
from
apps.webui.models.models
import
Models
from
apps.webui.utils
import
load_function_module_by_id
from
utils.misc
import
stream_message_template
from
utils.misc
import
(
openai_chat_chunk_message_template
,
openai_chat_completion_message_template
,
add_or_update_system_message
,
)
from
utils.task
import
prompt_template
from
config
import
(
WEBUI_BUILD_HASH
,
SHOW_ADMIN_DETAILS
,
ADMIN_EMAIL
,
WEBUI_AUTH
,
...
...
@@ -51,11 +51,9 @@ from config import (
from
apps.socket.main
import
get_event_call
,
get_event_emitter
import
inspect
import
uuid
import
time
import
json
from
typing
import
Iterator
,
Generator
,
AsyncGenerator
,
Optional
from
typing
import
Iterator
,
Generator
,
AsyncGenerator
from
pydantic
import
BaseModel
app
=
FastAPI
()
...
...
@@ -127,29 +125,29 @@ async def get_status():
}
def
get_function_module
(
pipe_id
:
str
):
# Check if function is already loaded
if
pipe_id
not
in
app
.
state
.
FUNCTIONS
:
function_module
,
_
,
_
=
load_function_module_by_id
(
pipe_id
)
app
.
state
.
FUNCTIONS
[
pipe_id
]
=
function_module
else
:
function_module
=
app
.
state
.
FUNCTIONS
[
pipe_id
]
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
pipe_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{}))
return
function_module
async
def
get_pipe_models
():
pipes
=
Functions
.
get_functions_by_type
(
"pipe"
,
active_only
=
True
)
pipe_models
=
[]
for
pipe
in
pipes
:
# Check if function is already loaded
if
pipe
.
id
not
in
app
.
state
.
FUNCTIONS
:
function_module
,
function_type
,
frontmatter
=
load_function_module_by_id
(
pipe
.
id
)
app
.
state
.
FUNCTIONS
[
pipe
.
id
]
=
function_module
else
:
function_module
=
app
.
state
.
FUNCTIONS
[
pipe
.
id
]
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
pipe
.
id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
function_module
=
get_function_module
(
pipe
.
id
)
# Check if function is a manifold
if
hasattr
(
function_module
,
"type"
):
if
function_module
.
type
==
"manifold"
:
if
hasattr
(
function_module
,
"pipes"
):
manifold_pipes
=
[]
# Check if pipes is a function or a list
...
...
@@ -163,9 +161,7 @@ async def get_pipe_models():
manifold_pipe_name
=
p
[
"name"
]
if
hasattr
(
function_module
,
"name"
):
manifold_pipe_name
=
(
f
"
{
function_module
.
name
}{
manifold_pipe_name
}
"
)
manifold_pipe_name
=
f
"
{
function_module
.
name
}{
manifold_pipe_name
}
"
pipe_flag
=
{
"type"
:
pipe
.
type
}
if
hasattr
(
function_module
,
"ChatValves"
):
...
...
@@ -200,127 +196,59 @@ async def get_pipe_models():
return
pipe_models
async
def
generate_function_chat_completion
(
form_data
,
user
):
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
metadata
=
None
if
"metadata"
in
form_data
:
metadata
=
form_data
[
"metadata"
]
del
form_data
[
"metadata"
]
__event_emitter__
=
None
__event_call__
=
None
__task__
=
None
if
metadata
:
if
(
metadata
.
get
(
"session_id"
)
and
metadata
.
get
(
"chat_id"
)
and
metadata
.
get
(
"message_id"
)
):
__event_emitter__
=
await
get_event_emitter
(
metadata
)
__event_call__
=
await
get_event_call
(
metadata
)
if
metadata
.
get
(
"task"
):
__task__
=
metadata
.
get
(
"task"
)
if
model_info
:
if
model_info
.
base_model_id
:
form_data
[
"model"
]
=
model_info
.
base_model_id
model_info
.
params
=
model_info
.
params
.
model_dump
()
async
def
execute_pipe
(
pipe
,
params
):
if
inspect
.
iscoroutinefunction
(
pipe
):
return
await
pipe
(
**
params
)
else
:
return
pipe
(
**
params
)
if
model_info
.
params
:
if
model_info
.
params
.
get
(
"temperature"
,
None
)
is
not
None
:
form_data
[
"temperature"
]
=
float
(
model_info
.
params
.
get
(
"temperature"
))
if
model_info
.
params
.
get
(
"top_p"
,
None
):
form_data
[
"top_p"
]
=
int
(
model_info
.
params
.
get
(
"top_p"
,
None
))
async
def
get_message_content
(
res
:
str
|
Generator
|
AsyncGenerator
)
->
str
:
if
isinstance
(
res
,
str
):
return
res
if
isinstance
(
res
,
Generator
):
return
""
.
join
(
map
(
str
,
res
))
if
isinstance
(
res
,
AsyncGenerator
):
return
""
.
join
([
str
(
stream
)
async
for
stream
in
res
])
if
model_info
.
params
.
get
(
"max_tokens"
,
None
):
form_data
[
"max_tokens"
]
=
int
(
model_info
.
params
.
get
(
"max_tokens"
,
None
))
if
model_info
.
params
.
get
(
"frequency_penalty"
,
None
):
form_data
[
"frequency_penalty"
]
=
int
(
model_info
.
params
.
get
(
"frequency_penalty"
,
None
)
)
def
process_line
(
form_data
:
dict
,
line
):
if
isinstance
(
line
,
BaseModel
):
line
=
line
.
model_dump_json
()
line
=
f
"data:
{
line
}
"
if
isinstance
(
line
,
dict
):
line
=
f
"data:
{
json
.
dumps
(
line
)
}
"
if
model_info
.
params
.
get
(
"seed"
,
None
):
form_data
[
"seed"
]
=
model_info
.
params
.
get
(
"seed"
,
None
)
if
model_info
.
params
.
get
(
"stop"
,
None
):
form_data
[
"stop"
]
=
(
[
bytes
(
stop
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
stop
in
model_info
.
params
[
"stop"
]
]
if
model_info
.
params
.
get
(
"stop"
,
None
)
else
None
)
try
:
line
=
line
.
decode
(
"utf-8"
)
except
Exception
:
pass
system
=
model_info
.
params
.
get
(
"system"
,
None
)
if
system
:
system
=
prompt_template
(
system
,
**
(
{
"user_name"
:
user
.
name
,
"user_location"
:
(
user
.
info
.
get
(
"location"
)
if
user
.
info
else
None
),
}
if
user
else
{}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if
form_data
.
get
(
"messages"
):
for
message
in
form_data
[
"messages"
]:
if
message
.
get
(
"role"
)
==
"system"
:
message
[
"content"
]
=
system
+
message
[
"content"
]
break
if
line
.
startswith
(
"data:"
):
return
f
"
{
line
}
\n\n
"
else
:
form_data
[
"messages"
].
insert
(
0
,
{
"role"
:
"system"
,
"content"
:
system
,
},
)
line
=
openai_chat_chunk_message_template
(
form_data
[
"model"
],
line
)
return
f
"data:
{
json
.
dumps
(
line
)
}
\n\n
"
else
:
pass
async
def
job
()
:
def
get_pipe_id
(
form_data
:
dict
)
->
str
:
pipe_id
=
form_data
[
"model"
]
if
"."
in
pipe_id
:
pipe_id
,
sub_pipe_id
=
pipe_id
.
split
(
"."
,
1
)
pipe_id
,
_
=
pipe_id
.
split
(
"."
,
1
)
print
(
pipe_id
)
return
pipe_id
# Check if function is already loaded
if
pipe_id
not
in
app
.
state
.
FUNCTIONS
:
function_module
,
function_type
,
frontmatter
=
load_function_module_by_id
(
pipe_id
)
app
.
state
.
FUNCTIONS
[
pipe_id
]
=
function_module
else
:
function_module
=
app
.
state
.
FUNCTIONS
[
pipe_id
]
if
hasattr
(
function_module
,
"valves"
)
and
hasattr
(
function_module
,
"Valves"
):
valves
=
Functions
.
get_function_valves_by_id
(
pipe_id
)
function_module
.
valves
=
function_module
.
Valves
(
**
(
valves
if
valves
else
{})
)
pipe
=
function_module
.
pipe
def
get_function_params
(
function_module
,
form_data
,
user
,
extra_params
=
{}):
pipe_id
=
get_pipe_id
(
form_data
)
# Get the signature of the function
sig
=
inspect
.
signature
(
pipe
)
sig
=
inspect
.
signature
(
function_module
.
pipe
)
params
=
{
"body"
:
form_data
}
for
key
,
value
in
extra_params
.
items
():
if
key
in
sig
.
parameters
:
params
[
key
]
=
value
if
"__user__"
in
sig
.
parameters
:
__user__
=
{
"id"
:
user
.
id
,
...
...
@@ -337,25 +265,94 @@ async def generate_function_chat_completion(form_data, user):
except
Exception
as
e
:
print
(
e
)
params
=
{
**
params
,
"__user__"
:
__user__
}
params
[
"__user__"
]
=
__user__
return
params
# inplace function: form_data is modified
def
apply_model_params_to_body
(
params
:
dict
,
form_data
:
dict
)
->
dict
:
if
not
params
:
return
form_data
mappings
=
{
"temperature"
:
float
,
"top_p"
:
int
,
"max_tokens"
:
int
,
"frequency_penalty"
:
int
,
"seed"
:
lambda
x
:
x
,
"stop"
:
lambda
x
:
[
bytes
(
s
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
s
in
x
],
}
for
key
,
cast_func
in
mappings
.
items
():
if
(
value
:
=
params
.
get
(
key
))
is
not
None
:
form_data
[
key
]
=
cast_func
(
value
)
return
form_data
# inplace function: form_data is modified
def
apply_model_system_prompt_to_body
(
params
:
dict
,
form_data
:
dict
,
user
)
->
dict
:
system
=
params
.
get
(
"system"
,
None
)
if
not
system
:
return
form_data
if
user
:
template_params
=
{
"user_name"
:
user
.
name
,
"user_location"
:
user
.
info
.
get
(
"location"
)
if
user
.
info
else
None
,
}
else
:
template_params
=
{}
system
=
prompt_template
(
system
,
**
template_params
)
form_data
[
"messages"
]
=
add_or_update_system_message
(
system
,
form_data
.
get
(
"messages"
,
[])
)
return
form_data
if
"__event_emitter__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__event_emitter__"
:
__event_emitter__
}
async
def
generate_function_chat_completion
(
form_data
,
user
):
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
metadata
=
form_data
.
pop
(
"metadata"
,
None
)
if
"__event_call__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__event_call__"
:
__event_call__
}
__event_emitter__
=
None
__event_call__
=
None
__task__
=
None
if
"__task__"
in
sig
.
parameters
:
params
=
{
**
params
,
"__task__"
:
__task__
}
if
metadata
:
if
all
(
k
in
metadata
for
k
in
(
"session_id"
,
"chat_id"
,
"message_id"
)):
__event_emitter__
=
get_event_emitter
(
metadata
)
__event_call__
=
get_event_call
(
metadata
)
__task__
=
metadata
.
get
(
"task"
,
None
)
if
model_info
:
if
model_info
.
base_model_id
:
form_data
[
"model"
]
=
model_info
.
base_model_id
params
=
model_info
.
params
.
model_dump
()
form_data
=
apply_model_params_to_body
(
params
,
form_data
)
form_data
=
apply_model_system_prompt_to_body
(
params
,
form_data
,
user
)
pipe_id
=
get_pipe_id
(
form_data
)
function_module
=
get_function_module
(
pipe_id
)
pipe
=
function_module
.
pipe
params
=
get_function_params
(
function_module
,
form_data
,
user
,
{
"__event_emitter__"
:
__event_emitter__
,
"__event_call__"
:
__event_call__
,
"__task__"
:
__task__
,
},
)
if
form_data
[
"stream"
]:
async
def
stream_content
():
try
:
if
inspect
.
iscoroutinefunction
(
pipe
):
res
=
await
pipe
(
**
params
)
else
:
res
=
pipe
(
**
params
)
res
=
await
execute_pipe
(
pipe
,
params
)
# Directly return if the response is a StreamingResponse
if
isinstance
(
res
,
StreamingResponse
):
...
...
@@ -372,112 +369,38 @@ async def generate_function_chat_completion(form_data, user):
return
if
isinstance
(
res
,
str
):
message
=
stream
_message_template
(
form_data
[
"model"
],
res
)
message
=
openai_chat_chunk
_message_template
(
form_data
[
"model"
],
res
)
yield
f
"data:
{
json
.
dumps
(
message
)
}
\n\n
"
if
isinstance
(
res
,
Iterator
):
for
line
in
res
:
if
isinstance
(
line
,
BaseModel
):
line
=
line
.
model_dump_json
()
line
=
f
"data:
{
line
}
"
if
isinstance
(
line
,
dict
):
line
=
f
"data:
{
json
.
dumps
(
line
)
}
"
try
:
line
=
line
.
decode
(
"utf-8"
)
except
:
pass
if
line
.
startswith
(
"data:"
):
yield
f
"
{
line
}
\n\n
"
else
:
line
=
stream_message_template
(
form_data
[
"model"
],
line
)
yield
f
"data:
{
json
.
dumps
(
line
)
}
\n\n
"
if
isinstance
(
res
,
str
)
or
isinstance
(
res
,
Generator
):
finish_message
=
{
"id"
:
f
"
{
form_data
[
'model'
]
}
-
{
str
(
uuid
.
uuid4
())
}
"
,
"object"
:
"chat.completion.chunk"
,
"created"
:
int
(
time
.
time
()),
"model"
:
form_data
[
"model"
],
"choices"
:
[
{
"index"
:
0
,
"delta"
:
{},
"logprobs"
:
None
,
"finish_reason"
:
"stop"
,
}
],
}
yield
f
"data:
{
json
.
dumps
(
finish_message
)
}
\n\n
"
yield
f
"data: [DONE]"
yield
process_line
(
form_data
,
line
)
if
isinstance
(
res
,
AsyncGenerator
):
async
for
line
in
res
:
if
isinstance
(
line
,
BaseModel
):
line
=
line
.
model_dump_json
()
line
=
f
"data:
{
line
}
"
if
isinstance
(
line
,
dict
):
line
=
f
"data:
{
json
.
dumps
(
line
)
}
"
try
:
line
=
line
.
decode
(
"utf-8"
)
except
:
pass
yield
process_line
(
form_data
,
line
)
if
line
.
startswith
(
"data:"
):
yield
f
"
{
line
}
\n\n
"
else
:
line
=
stream_message_template
(
form_data
[
"model"
],
line
)
yield
f
"data:
{
json
.
dumps
(
line
)
}
\n\n
"
if
isinstance
(
res
,
str
)
or
isinstance
(
res
,
Generator
):
finish_message
=
openai_chat_chunk_message_template
(
form_data
[
"model"
],
""
)
finish_message
[
"choices"
][
0
][
"finish_reason"
]
=
"stop"
yield
f
"data:
{
json
.
dumps
(
finish_message
)
}
\n\n
"
yield
"data: [DONE]"
return
StreamingResponse
(
stream_content
(),
media_type
=
"text/event-stream"
)
else
:
try
:
if
inspect
.
iscoroutinefunction
(
pipe
):
res
=
await
pipe
(
**
params
)
else
:
res
=
pipe
(
**
params
)
res
=
await
execute_pipe
(
pipe
,
params
)
if
isinstance
(
res
,
StreamingResponse
):
return
res
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
return
{
"error"
:
{
"detail"
:
str
(
e
)}}
if
isinstance
(
res
,
dict
):
if
isinstance
(
res
,
StreamingResponse
)
or
isinstance
(
res
,
dict
):
return
res
el
if
isinstance
(
res
,
BaseModel
):
if
isinstance
(
res
,
BaseModel
):
return
res
.
model_dump
()
else
:
message
=
""
if
isinstance
(
res
,
str
):
message
=
res
elif
isinstance
(
res
,
Generator
):
for
stream
in
res
:
message
=
f
"
{
message
}{
stream
}
"
elif
isinstance
(
res
,
AsyncGenerator
):
async
for
stream
in
res
:
message
=
f
"
{
message
}{
stream
}
"
return
{
"id"
:
f
"
{
form_data
[
'model'
]
}
-
{
str
(
uuid
.
uuid4
())
}
"
,
"object"
:
"chat.completion"
,
"created"
:
int
(
time
.
time
()),
"model"
:
form_data
[
"model"
],
"choices"
:
[
{
"index"
:
0
,
"message"
:
{
"role"
:
"assistant"
,
"content"
:
message
,
},
"logprobs"
:
None
,
"finish_reason"
:
"stop"
,
}
],
}
return
await
job
()
message
=
await
get_message_content
(
res
)
return
openai_chat_completion_message_template
(
form_data
[
"model"
],
message
)
backend/apps/webui/models/models.py
View file @
a9a6ed8b
import
json
import
logging
from
typing
import
Optional
from
typing
import
Optional
,
List
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
sqlalchemy
import
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_db
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
import
time
...
...
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
class
ModelsTable
:
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
)
->
Optional
[
ModelModel
]:
...
...
@@ -126,9 +123,7 @@ class ModelsTable:
}
)
try
:
with
get_db
()
as
db
:
result
=
Model
(
**
model
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
...
...
@@ -144,13 +139,11 @@ class ModelsTable:
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
with
get_db
()
as
db
:
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
with
get_db
()
as
db
:
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
.
model_validate
(
model
)
except
:
...
...
@@ -178,7 +171,6 @@ class ModelsTable:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_db
()
as
db
:
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
...
...
backend/main.py
View file @
a9a6ed8b
...
...
@@ -13,8 +13,6 @@ import aiohttp
import
requests
import
mimetypes
import
shutil
import
os
import
uuid
import
inspect
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
,
UploadFile
,
File
,
Form
...
...
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
,
RedirectResponse
from
apps.socket.main
import
sio
,
app
as
socket_app
,
get_event_emitter
,
get_event_call
from
apps.socket.main
import
app
as
socket_app
,
get_event_emitter
,
get_event_call
from
apps.ollama.main
import
(
app
as
ollama_app
,
get_all_models
as
get_ollama_models
,
...
...
@@ -619,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content
=
{
"detail"
:
str
(
e
)},
)
# Extract valves from the request body
valves
=
None
if
"valves"
in
body
:
valves
=
body
[
"valves"
]
del
body
[
"valves"
]
metadata
=
{
"chat_id"
:
body
.
pop
(
"chat_id"
,
None
),
"message_id"
:
body
.
pop
(
"id"
,
None
),
"session_id"
:
body
.
pop
(
"session_id"
,
None
),
"valves"
:
body
.
pop
(
"valves"
,
None
),
}
# Extract session_id, chat_id and message_id from the request body
session_id
=
None
if
"session_id"
in
body
:
session_id
=
body
[
"session_id"
]
del
body
[
"session_id"
]
chat_id
=
None
if
"chat_id"
in
body
:
chat_id
=
body
[
"chat_id"
]
del
body
[
"chat_id"
]
message_id
=
None
if
"id"
in
body
:
message_id
=
body
[
"id"
]
del
body
[
"id"
]
__event_emitter__
=
await
get_event_emitter
(
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
)
__event_call__
=
await
get_event_call
(
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
)
__event_emitter__
=
get_event_emitter
(
metadata
)
__event_call__
=
get_event_call
(
metadata
)
# Initialize data_items to store additional data to be sent to the client
data_items
=
[]
...
...
@@ -709,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if
len
(
citations
)
>
0
:
data_items
.
append
({
"citations"
:
citations
})
body
[
"metadata"
]
=
{
"session_id"
:
session_id
,
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"valves"
:
valves
,
}
body
[
"metadata"
]
=
metadata
modified_body_bytes
=
json
.
dumps
(
body
).
encode
(
"utf-8"
)
# Replace the request body with the modified one
request
.
_body
=
modified_body_bytes
...
...
@@ -1191,13 +1166,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
status_code
=
r
.
status_code
,
content
=
res
,
)
except
:
except
Exception
:
pass
else
:
pass
__event_emitter__
=
await
get_event_emitter
(
__event_emitter__
=
get_event_emitter
(
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
...
...
@@ -1205,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
}
)
__event_call__
=
await
get_event_call
(
__event_call__
=
get_event_call
(
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
...
...
@@ -1334,14 +1309,14 @@ async def chat_completed(
)
model
=
app
.
state
.
MODELS
[
model_id
]
__event_emitter__
=
await
get_event_emitter
(
__event_emitter__
=
get_event_emitter
(
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"session_id"
:
data
[
"session_id"
],
}
)
__event_call__
=
await
get_event_call
(
__event_call__
=
get_event_call
(
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
...
...
backend/utils/misc.py
View file @
a9a6ed8b
from
pathlib
import
Path
import
hashlib
import
json
import
re
from
datetime
import
timedelta
from
typing
import
Optional
,
List
,
Tuple
...
...
@@ -8,37 +7,39 @@ import uuid
import
time
def
get_last_user_message_item
(
messages
:
List
[
dict
])
->
str
:
def
get_last_user_message_item
(
messages
:
List
[
dict
])
->
Optional
[
dict
]
:
for
message
in
reversed
(
messages
):
if
message
[
"role"
]
==
"user"
:
return
message
return
None
def
get_last_user_message
(
messages
:
List
[
dict
])
->
str
:
message
=
get_last_user_message_item
(
messages
)
if
message
is
not
None
:
def
get_content_from_message
(
message
:
dict
)
->
Optional
[
str
]:
if
isinstance
(
message
[
"content"
],
list
):
for
item
in
message
[
"content"
]:
if
item
[
"type"
]
==
"text"
:
return
item
[
"text"
]
else
:
return
message
[
"content"
]
return
None
def
get_last_assistant_message
(
messages
:
List
[
dict
])
->
str
:
def
get_last_user_message
(
messages
:
List
[
dict
])
->
Optional
[
str
]:
message
=
get_last_user_message_item
(
messages
)
if
message
is
None
:
return
None
return
get_content_from_message
(
message
)
def
get_last_assistant_message
(
messages
:
List
[
dict
])
->
Optional
[
str
]:
for
message
in
reversed
(
messages
):
if
message
[
"role"
]
==
"assistant"
:
if
isinstance
(
message
[
"content"
],
list
):
for
item
in
message
[
"content"
]:
if
item
[
"type"
]
==
"text"
:
return
item
[
"text"
]
return
message
[
"content"
]
return
get_content_from_message
(
message
)
return
None
def
get_system_message
(
messages
:
List
[
dict
])
->
dict
:
def
get_system_message
(
messages
:
List
[
dict
])
->
Optional
[
dict
]
:
for
message
in
messages
:
if
message
[
"role"
]
==
"system"
:
return
message
...
...
@@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]:
return
[
message
for
message
in
messages
if
message
[
"role"
]
!=
"system"
]
def
pop_system_message
(
messages
:
List
[
dict
])
->
Tuple
[
dict
,
List
[
dict
]]:
def
pop_system_message
(
messages
:
List
[
dict
])
->
Tuple
[
Optional
[
dict
]
,
List
[
dict
]]:
return
get_system_message
(
messages
),
remove_system_message
(
messages
)
...
...
@@ -87,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return
messages
def
stream
_message_template
(
model
:
str
,
message
:
str
):
def
openai_chat
_message_template
(
model
:
str
):
return
{
"id"
:
f
"
{
model
}
-
{
str
(
uuid
.
uuid4
())
}
"
,
"object"
:
"chat.completion.chunk"
,
"created"
:
int
(
time
.
time
()),
"model"
:
model
,
"choices"
:
[
{
"index"
:
0
,
"delta"
:
{
"content"
:
message
},
"logprobs"
:
None
,
"finish_reason"
:
None
,
}
],
"choices"
:
[{
"index"
:
0
,
"logprobs"
:
None
,
"finish_reason"
:
None
}],
}
def
openai_chat_chunk_message_template
(
model
:
str
,
message
:
str
):
template
=
openai_chat_message_template
(
model
)
template
[
"object"
]
=
"chat.completion.chunk"
template
[
"choices"
][
0
][
"delta"
]
=
{
"content"
:
message
}
return
template
def
openai_chat_completion_message_template
(
model
:
str
,
message
:
str
):
template
=
openai_chat_message_template
(
model
)
template
[
"object"
]
=
"chat.completion"
template
[
"choices"
][
0
][
"message"
]
=
{
"content"
:
message
,
"role"
:
"assistant"
}
template
[
"choices"
][
0
][
"finish_reason"
]
=
"stop"
def
get_gravatar_url
(
email
):
# Trim leading and trailing whitespace from
# an email address and force all characters
...
...
@@ -174,7 +181,7 @@ def extract_folders_after_data_docs(path):
tags
=
[]
folders
=
parts
[
index_docs
:
-
1
]
for
idx
,
part
in
enumerate
(
folders
):
for
idx
,
_
in
enumerate
(
folders
):
tags
.
append
(
"/"
.
join
(
folders
[:
idx
+
1
]))
return
tags
...
...
@@ -270,11 +277,11 @@ def parse_ollama_modelfile(model_text):
value
=
param_match
.
group
(
1
)
try
:
if
param_type
==
int
:
if
param_type
is
int
:
value
=
int
(
value
)
elif
param_type
==
float
:
elif
param_type
is
float
:
value
=
float
(
value
)
elif
param_type
==
bool
:
elif
param_type
is
bool
:
value
=
value
.
lower
()
==
"true"
except
Exception
as
e
:
print
(
e
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment