Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
6b8a7b99
Commit
6b8a7b99
authored
Jun 20, 2024
by
Timothy J. Baek
Browse files
refac: chat completion middleware
parent
448ca9d8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
22 deletions
+20
-22
backend/main.py
backend/main.py
+20
-22
No files found.
backend/main.py
View file @
6b8a7b99
...
@@ -316,7 +316,7 @@ async def get_function_call_response(
...
@@ -316,7 +316,7 @@ async def get_function_call_response(
class
ChatCompletionMiddleware
(
BaseHTTPMiddleware
):
class
ChatCompletionMiddleware
(
BaseHTTPMiddleware
):
async
def
dispatch
(
self
,
request
:
Request
,
call_next
):
async
def
dispatch
(
self
,
request
:
Request
,
call_next
):
return_citations
=
False
data_items
=
[]
if
request
.
method
==
"POST"
and
(
if
request
.
method
==
"POST"
and
(
"/ollama/api/chat"
in
request
.
url
.
path
"/ollama/api/chat"
in
request
.
url
.
path
...
@@ -326,23 +326,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -326,23 +326,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Read the original request body
# Read the original request body
body
=
await
request
.
body
()
body
=
await
request
.
body
()
# Decode body to string
body_str
=
body
.
decode
(
"utf-8"
)
body_str
=
body
.
decode
(
"utf-8"
)
# Parse string to JSON
data
=
json
.
loads
(
body_str
)
if
body_str
else
{}
data
=
json
.
loads
(
body_str
)
if
body_str
else
{}
model_id
=
data
[
"model"
]
user
=
get_current_user
(
user
=
get_current_user
(
request
,
request
,
get_http_authorization_cred
(
request
.
headers
.
get
(
"Authorization"
)),
get_http_authorization_cred
(
request
.
headers
.
get
(
"Authorization"
)),
)
)
# Remove the citations from the body
return_citations
=
data
.
get
(
"citations"
,
False
)
if
"citations"
in
data
:
del
data
[
"citations"
]
# Set the task model
# Set the task model
task_model_id
=
data
[
"
model
"
]
task_model_id
=
model
_id
if
task_model_id
not
in
app
.
state
.
MODELS
:
if
task_model_id
not
in
app
.
state
.
MODELS
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_404_NOT_FOUND
,
status_code
=
status
.
HTTP_404_NOT_FOUND
,
...
@@ -364,12 +358,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -364,12 +358,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
):
):
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
skip_files
=
False
prompt
=
get_last_user_message
(
data
[
"messages"
])
prompt
=
get_last_user_message
(
data
[
"messages"
])
context
=
""
context
=
""
# If tool_ids field is present, call the functions
# If tool_ids field is present, call the functions
skip_files
=
False
if
"tool_ids"
in
data
:
if
"tool_ids"
in
data
:
print
(
data
[
"tool_ids"
])
print
(
data
[
"tool_ids"
])
for
tool_id
in
data
[
"tool_ids"
]:
for
tool_id
in
data
[
"tool_ids"
]:
...
@@ -415,8 +408,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -415,8 +408,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
rag_context
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
rag_context
log
.
debug
(
f
"rag_context:
{
rag_context
}
, citations:
{
citations
}
"
)
log
.
debug
(
f
"rag_context:
{
rag_context
}
, citations:
{
citations
}
"
)
else
:
return_citations
=
False
if
citations
:
data_items
.
append
({
"citations"
:
citations
})
del
data
[
"files"
]
del
data
[
"files"
]
...
@@ -426,7 +420,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -426,7 +420,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
)
print
(
system_prompt
)
print
(
system_prompt
)
data
[
"messages"
]
=
add_or_update_system_message
(
data
[
"messages"
]
=
add_or_update_system_message
(
f
"
\n
{
system_prompt
}
"
,
data
[
"messages"
]
system_prompt
,
data
[
"messages"
]
)
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
...
@@ -444,18 +438,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -444,18 +438,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
response
=
await
call_next
(
request
)
response
=
await
call_next
(
request
)
if
return_citati
ons
:
# If there are data_items to inject into the resp
ons
e
# Inject the citations into the response
if
len
(
data_items
)
>
0
:
if
isinstance
(
response
,
StreamingResponse
):
if
isinstance
(
response
,
StreamingResponse
):
# If it's a streaming response, inject it as SSE event or NDJSON line
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type
=
response
.
headers
.
get
(
"Content-Type"
)
content_type
=
response
.
headers
.
get
(
"Content-Type"
)
if
"text/event-stream"
in
content_type
:
if
"text/event-stream"
in
content_type
:
return
StreamingResponse
(
return
StreamingResponse
(
self
.
openai_stream_wrapper
(
response
.
body_iterator
,
citation
s
),
self
.
openai_stream_wrapper
(
response
.
body_iterator
,
data_item
s
),
)
)
if
"application/x-ndjson"
in
content_type
:
if
"application/x-ndjson"
in
content_type
:
return
StreamingResponse
(
return
StreamingResponse
(
self
.
ollama_stream_wrapper
(
response
.
body_iterator
,
citation
s
),
self
.
ollama_stream_wrapper
(
response
.
body_iterator
,
data_item
s
),
)
)
return
response
return
response
...
@@ -463,13 +457,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -463,13 +457,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
async
def
_receive
(
self
,
body
:
bytes
):
async
def
_receive
(
self
,
body
:
bytes
):
return
{
"type"
:
"http.request"
,
"body"
:
body
,
"more_body"
:
False
}
return
{
"type"
:
"http.request"
,
"body"
:
body
,
"more_body"
:
False
}
async
def
openai_stream_wrapper
(
self
,
original_generator
,
citations
):
async
def
openai_stream_wrapper
(
self
,
original_generator
,
data_items
):
yield
f
"data:
{
json
.
dumps
(
{
'citations'
:
citations
}
)
}
\n\n
"
for
item
in
data_items
:
yield
f
"data:
{
json
.
dumps
(
item
)
}
\n\n
"
async
for
data
in
original_generator
:
async
for
data
in
original_generator
:
yield
data
yield
data
async
def
ollama_stream_wrapper
(
self
,
original_generator
,
citations
):
async
def
ollama_stream_wrapper
(
self
,
original_generator
,
data_items
):
yield
f
"
{
json
.
dumps
(
{
'citations'
:
citations
}
)
}
\n
"
for
item
in
data_items
:
yield
f
"
{
json
.
dumps
(
item
)
}
\n
"
async
for
data
in
original_generator
:
async
for
data
in
original_generator
:
yield
data
yield
data
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment