chenpangpang / open-webui / Commits / c7a9b5cc

Commit c7a9b5cc authored Jul 01, 2024 by Timothy J. Baek

refac: chat completion middleware

parent b62d2a9b

Showing 3 changed files with 304 additions and 223 deletions (+304 −223)
backend/apps/rag/utils.py            +7   −7
backend/main.py                      +291 −210
src/lib/components/chat/Chat.svelte  +6   −6
backend/apps/rag/utils.py

@@ -294,15 +294,17 @@ def get_rag_context(
         extracted_collections.extend(collection_names)
 
-    context_string = ""
+    contexts = []
     citations = []
 
     for context in relevant_contexts:
         try:
             if "documents" in context:
-                context_string += "\n\n".join(
-                    [text for text in context["documents"][0] if text is not None]
-                )
+                contexts.append(
+                    "\n\n".join(
+                        [text for text in context["documents"][0] if text is not None]
+                    )
+                )
 
             if "metadatas" in context:
                 citations.append(
@@ -315,9 +317,7 @@ def get_rag_context(
         except Exception as e:
             log.exception(e)
 
-    context_string = context_string.strip()
-
-    return context_string, citations
+    return contexts, citations
 
 
 def get_model_path(model: str, update_model: bool = False):
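Note on the hunk above: the behavioral change is the return type. get_rag_context now hands back a list of per-source context strings rather than one pre-joined, stripped string, leaving the joining to the caller. A minimal sketch of what that shift looks like at a call site (illustrative stand-in, not code from this commit):

    # Illustrative only: mimics the new contract of get_rag_context, which
    # returns (contexts: list[str], citations) instead of (context_string, citations).
    def join_contexts(contexts: list[str]) -> str:
        # The caller now decides how to merge per-source contexts; the
        # middleware in this commit joins them before rendering the RAG template.
        return "\n\n".join(c for c in contexts if c).strip()

    contexts = ["text from document A", "text from document B"]
    print(join_contexts(contexts))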
backend/main.py

@@ -213,7 +213,7 @@ origins = ["*"]
 async def get_function_call_response(
-    messages, files, tool_id, template, task_model_id, user
+    messages, files, tool_id, template, task_model_id, user, model
 ):
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
@@ -373,68 +373,55 @@ async def get_function_call_response(
     return None, None, False
 
 
-class ChatCompletionMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        data_items = []
-        show_citations = False
-        citations = []
-
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
-
-            # Read the original request body
-            body = await request.body()
-            body_str = body.decode("utf-8")
-            data = json.loads(body_str) if body_str else {}
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
-
-            # Flag to skip RAG completions if file_handler is present in tools/functions
-            skip_files = False
-
-            if data.get("citations"):
-                show_citations = True
-                del data["citations"]
-
-            model_id = data["model"]
-            if model_id not in app.state.MODELS:
-                raise HTTPException(
-                    status_code=status.HTTP_404_NOT_FOUND,
-                    detail="Model not found",
-                )
-            model = app.state.MODELS[model_id]
-
-            def get_priority(function_id):
-                function = Functions.get_function_by_id(function_id)
-                if function is not None and hasattr(function, "valves"):
-                    return (function.valves if function.valves else {}).get(
-                        "priority", 0
-                    )
-                return 0
-
-            filter_ids = [
-                function.id for function in Functions.get_global_filter_functions()
-            ]
-            if "info" in model and "meta" in model["info"]:
-                filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-                filter_ids = list(set(filter_ids))
-
-            enabled_filter_ids = [
-                function.id
-                for function in Functions.get_functions_by_type(
-                    "filter", active_only=True
-                )
-            ]
-            filter_ids = [
-                filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-            ]
-
-            filter_ids.sort(key=get_priority)
-            for filter_id in filter_ids:
+def get_task_model_id(default_model_id):
+    # Set the task model
+    task_model_id = default_model_id
+    # Check if the user has a custom task model and use that model
+    if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
+        if (
+            app.state.config.TASK_MODEL
+            and app.state.config.TASK_MODEL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL
+    else:
+        if (
+            app.state.config.TASK_MODEL_EXTERNAL
+            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+
+    return task_model_id
+
+
+def get_filter_function_ids(model):
+    def get_priority(function_id):
+        function = Functions.get_function_by_id(function_id)
+        if function is not None and hasattr(function, "valves"):
+            return (function.valves if function.valves else {}).get("priority", 0)
+        return 0
+
+    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+
+    if "info" in model and "meta" in model["info"]:
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+        filter_ids = list(set(filter_ids))
+
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+    filter_ids = [
+        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+    ]
+
+    filter_ids.sort(key=get_priority)
+    return filter_ids
+
+
+async def chat_completion_functions_handler(body, model, user):
+    skip_files = None
+
+    filter_ids = get_filter_function_ids(model)
+    for filter_id in filter_ids:
         filter = Functions.get_function_by_id(filter_id)
         if filter:
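For readers skimming the hunk above: the extracted get_task_model_id simply prefers an admin-configured task model (local or external, depending on who owns the requesting model) and falls back to the request's own model otherwise. A self-contained sketch of that selection logic, with hypothetical stand-ins for app.state (not the commit's code):

    # Self-contained sketch of get_task_model_id's selection logic; MODELS,
    # TASK_MODEL, and TASK_MODEL_EXTERNAL are hypothetical stand-ins for app.state.
    MODELS = {"llama3": {"owned_by": "ollama"}, "gpt-4o": {"owned_by": "openai"}}
    TASK_MODEL = None            # admin-configured local task model, if any
    TASK_MODEL_EXTERNAL = None   # admin-configured external task model, if any

    def get_task_model_id(default_model_id: str) -> str:
        if MODELS[default_model_id]["owned_by"] == "ollama":
            if TASK_MODEL and TASK_MODEL in MODELS:
                return TASK_MODEL
        else:
            if TASK_MODEL_EXTERNAL and TASK_MODEL_EXTERNAL in MODELS:
                return TASK_MODEL_EXTERNAL
        return default_model_id

    print(get_task_model_id("gpt-4o"))  # no override configured -> "gpt-4o"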
@@ -464,7 +451,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             # Get the signature of the function
             sig = inspect.signature(inlet)
-            params = {"body": data}
+            params = {"body": body}
 
             if "__user__" in sig.parameters:
                 __user__ = {
@@ -499,107 +486,195 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
}
if
inspect
.
iscoroutinefunction
(
inlet
):
data
=
await
inlet
(
**
params
)
body
=
await
inlet
(
**
params
)
else
:
data
=
inlet
(
**
params
)
body
=
inlet
(
**
params
)
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
return
JSONResponse
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
content
=
{
"detail"
:
str
(
e
)},
)
raise
e
if
skip_files
:
if
"files"
in
body
:
del
body
[
"files"
]
return
body
,
{}
# Set the task model
task_model_id
=
data
[
"model"
]
# Check if the user has a custom task model and use that model
if
app
.
state
.
MODELS
[
task_model_id
][
"owned_by"
]
==
"ollama"
:
if
(
app
.
state
.
config
.
TASK_MODEL
and
app
.
state
.
config
.
TASK_MODEL
in
app
.
state
.
MODELS
):
task_model_id
=
app
.
state
.
config
.
TASK_MODEL
else
:
if
(
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
and
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
in
app
.
state
.
MODELS
):
task_model_id
=
app
.
state
.
config
.
TASK_MODEL_EXTERNAL
prompt
=
get_last_user_message
(
data
[
"messages"
])
context
=
""
async
def
chat_completion_tools_handler
(
body
,
model
,
user
):
skip_files
=
None
contexts
=
[]
citations
=
None
task_model_id
=
get_task_model_id
(
body
[
"model"
])
# If tool_ids field is present, call the functions
if
"tool_ids"
in
data
:
print
(
data
[
"tool_ids"
])
for
tool_id
in
data
[
"tool_ids"
]:
if
"tool_ids"
in
body
:
print
(
body
[
"tool_ids"
])
for
tool_id
in
body
[
"tool_ids"
]:
print
(
tool_id
)
try
:
response
,
citation
,
file_handler
=
(
await
get_function_call_response
(
messages
=
data
[
"messages"
],
files
=
data
.
get
(
"files"
,
[]),
response
,
citation
,
file_handler
=
await
get_function_call_response
(
messages
=
body
[
"messages"
],
files
=
body
.
get
(
"files"
,
[]),
tool_id
=
tool_id
,
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
,
task_model_id
=
task_model_id
,
user
=
user
,
)
model
=
model
,
)
print
(
file_handler
)
if
isinstance
(
response
,
str
):
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
response
contexts
.
append
(
response
)
if
citation
:
if
citations
is
None
:
citations
=
[
citation
]
else
:
citations
.
append
(
citation
)
show_citations
=
True
if
file_handler
:
skip_files
=
True
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
del
data
[
"tool_ids"
]
print
(
f
"tool_context:
{
context
}
"
)
# If files field is present, generate RAG completions
# If skip_files is True, skip the RAG completions
if
"files"
in
data
:
if
not
skip_files
:
data
=
{
**
data
}
rag_context
,
rag_citations
=
get_rag_context
(
files
=
data
[
"files"
],
messages
=
data
[
"messages"
],
del
body
[
"tool_ids"
]
print
(
f
"tool_contexts:
{
contexts
}
"
)
if
skip_files
:
if
"files"
in
body
:
del
body
[
"files"
]
return
body
,
{
**
({
"contexts"
:
contexts
}
if
contexts
is
not
None
else
{}),
**
({
"citations"
:
citations
}
if
citations
is
not
None
else
{}),
}
async
def
chat_completion_files_handler
(
body
):
contexts
=
[]
citations
=
None
if
"files"
in
body
:
files
=
body
[
"files"
]
del
body
[
"files"
]
contexts
,
citations
=
get_rag_context
(
files
=
files
,
messages
=
body
[
"messages"
],
embedding_function
=
rag_app
.
state
.
EMBEDDING_FUNCTION
,
k
=
rag_app
.
state
.
config
.
TOP_K
,
reranking_function
=
rag_app
.
state
.
sentence_transformer_rf
,
r
=
rag_app
.
state
.
config
.
RELEVANCE_THRESHOLD
,
hybrid_search
=
rag_app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
,
)
if
rag_context
:
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
rag_context
log
.
debug
(
f
"rag_context:
{
rag_
context
}
, citations:
{
citations
}
"
)
log
.
debug
(
f
"rag_context
s
:
{
context
s
}
, citations:
{
citations
}
"
)
if
rag_citations
:
citations
.
extend
(
rag_citations
)
return
body
,
{
**
({
"contexts"
:
contexts
}
if
contexts
is
not
None
else
{}),
**
({
"citations"
:
citations
}
if
citations
is
not
None
else
{}),
}
del
data
[
"files"
]
if
show_citations
and
len
(
citations
)
>
0
:
data_items
.
append
({
"citations"
:
citations
})
async
def
get_body_and_model_and_user
(
request
):
# Read the original request body
body
=
await
request
.
body
()
body_str
=
body
.
decode
(
"utf-8"
)
body
=
json
.
loads
(
body_str
)
if
body_str
else
{}
model_id
=
body
[
"model"
]
if
model_id
not
in
app
.
state
.
MODELS
:
raise
"Model not found"
model
=
app
.
state
.
MODELS
[
model_id
]
if
context
!=
""
:
system_prompt
=
rag_template
(
rag_app
.
state
.
config
.
RAG_TEMPLATE
,
context
,
prompt
user
=
get_current_user
(
request
,
get_http_authorization_cred
(
request
.
headers
.
get
(
"Authorization"
)),
)
print
(
system_prompt
)
data
[
"messages"
]
=
add_or_update_system_message
(
system_prompt
,
data
[
"messages"
]
return
body
,
model
,
user
class
ChatCompletionMiddleware
(
BaseHTTPMiddleware
):
async
def
dispatch
(
self
,
request
:
Request
,
call_next
):
if
request
.
method
==
"POST"
and
any
(
endpoint
in
request
.
url
.
path
for
endpoint
in
[
"/ollama/api/chat"
,
"/chat/completions"
]
):
log
.
debug
(
f
"request.url.path:
{
request
.
url
.
path
}
"
)
try
:
body
,
model
,
user
=
await
get_body_and_model_and_user
(
request
)
except
Exception
as
e
:
return
JSONResponse
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
content
=
{
"detail"
:
str
(
e
)},
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
# Extract chat_id and message_id from the request body
chat_id
=
None
if
"chat_id"
in
body
:
chat_id
=
body
[
"chat_id"
]
del
body
[
"chat_id"
]
message_id
=
None
if
"id"
in
body
:
message_id
=
body
[
"id"
]
del
body
[
"id"
]
# Initialize data_items to store additional data to be sent to the client
data_items
=
[]
# Initialize context, and citations
contexts
=
[]
citations
=
[]
print
(
body
)
try
:
body
,
flags
=
await
chat_completion_functions_handler
(
body
,
model
,
user
)
except
Exception
as
e
:
return
JSONResponse
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
content
=
{
"detail"
:
str
(
e
)},
)
try
:
body
,
flags
=
await
chat_completion_tools_handler
(
body
,
model
,
user
)
contexts
.
extend
(
flags
.
get
(
"contexts"
,
[]))
citations
.
extend
(
flags
.
get
(
"citations"
,
[]))
except
Exception
as
e
:
print
(
e
)
pass
try
:
body
,
flags
=
await
chat_completion_files_handler
(
body
)
contexts
.
extend
(
flags
.
get
(
"contexts"
,
[]))
citations
.
extend
(
flags
.
get
(
"citations"
,
[]))
except
Exception
as
e
:
print
(
e
)
pass
# If context is not empty, insert it into the messages
if
len
(
contexts
)
>
0
:
context_string
=
"/n"
.
join
(
contexts
).
strip
()
prompt
=
get_last_user_message
(
body
[
"messages"
])
body
[
"messages"
]
=
add_or_update_system_message
(
rag_template
(
rag_app
.
state
.
config
.
RAG_TEMPLATE
,
context_string
,
prompt
),
body
[
"messages"
],
)
# If there are citations, add them to the data_items
if
len
(
citations
)
>
0
:
data_items
.
append
({
"citations"
:
citations
})
modified_body_bytes
=
json
.
dumps
(
body
).
encode
(
"utf-8"
)
# Replace the request body with the modified one
request
.
_body
=
modified_body_bytes
# Set custom header to ensure content-length matches new body length
...
...
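The net effect of the hunk above is that dispatch becomes a thin pipeline: each handler takes the parsed body and returns (body, flags), and the middleware merges the contexts/citations flags. A runnable sketch of that composition pattern (simplified; FastAPI plumbing and error responses omitted, handler below is a stand-in):

    import asyncio

    # Simplified model of how the new dispatch composes its handlers: each one
    # returns (body, flags), and contexts/citations accumulate from the flags.
    async def apply_handlers(body, handlers):
        contexts, citations = [], []
        for handler in handlers:
            body, flags = await handler(body)
            contexts.extend(flags.get("contexts", []))
            citations.extend(flags.get("citations", []))
        return body, contexts, citations

    async def fake_tools_handler(body):
        # Stand-in for chat_completion_tools_handler: like the real one, it
        # only includes keys whose values are not None.
        return body, {"contexts": ["tool output"]}

    body, contexts, citations = asyncio.run(
        apply_handlers({"messages": []}, [fake_tools_handler])
    )
    print(contexts)  # ['tool output']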
@@ -721,9 +796,6 @@ def filter_pipeline(payload, user):
         pass
 
     if "pipeline" not in app.state.MODELS[model_id]:
-        if "chat_id" in payload:
-            del payload["chat_id"]
-
         if "title" in payload:
             del payload["title"]
@@ -1225,6 +1297,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
@@ -1285,6 +1360,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
             content={"detail": e.args[1]},
        )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
@@ -1349,6 +1427,9 @@ Message: """{{prompt}}"""
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
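The three task endpoints above gain the same guard because filter_pipeline (see the @@ -721 hunk) no longer deletes chat_id itself; each caller now strips it before forwarding to generate_chat_completions. A one-line equivalent for reference (hypothetical helper, not in the commit, which inlines the check instead):

    # Hypothetical helper equivalent to the repeated guard
    # `if "chat_id" in payload: del payload["chat_id"]`.
    def strip_chat_id(payload: dict) -> dict:
        payload.pop("chat_id", None)
        return payload

    print(strip_chat_id({"model": "llama3", "chat_id": "abc123"}))  # {'model': 'llama3'}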
src/lib/components/chat/Chat.svelte

@@ -665,6 +665,7 @@
 		await tick();
 
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
+			stream: true,
 			model: model.id,
 			messages: messagesBody,
 			options: {
@@ -682,8 +683,8 @@
 			keep_alive: $settings.keepAlive ?? undefined,
 			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			files: files.length > 0 ? files : undefined,
-			citations: files.length > 0 ? true : undefined,
-			chat_id: $chatId
+			chat_id: $chatId,
+			id: responseMessageId
 		});
 
 		if (res && res.ok) {
@@ -912,8 +913,8 @@
 		const [res, controller] = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
-				model: model.id,
 				stream: true,
+				model: model.id,
 				stream_options:
 					model.info?.meta?.capabilities?.usage ?? false
 						? {
@@ -983,9 +984,8 @@
 				max_tokens: $settings?.params?.max_tokens ?? undefined,
 				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 				files: files.length > 0 ? files : undefined,
-				citations: files.length > 0 ? true : undefined,
-				chat_id: $chatId
+				chat_id: $chatId,
+				id: responseMessageId
 			},
 			`${WEBUI_BASE_URL}/api`
 		);
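On the wire, the frontend changes mean both the Ollama and OpenAI completion calls now send chat_id plus the response message id, and drop the old citations flag; the middleware pops those fields before the payload reaches the model backend. A hedged sketch of the resulting round trip (field values are illustrative, not from the commit):

    import json

    # Illustrative completion payload as the updated Chat.svelte would send it;
    # the model name and ids are made up for the example.
    payload = {
        "model": "llama3",
        "stream": True,
        "messages": [{"role": "user", "content": "hello"}],
        "chat_id": "abc123",         # popped by ChatCompletionMiddleware
        "id": "response-message-1",  # popped as message_id by the middleware
    }

    # What the middleware does before forwarding upstream (mirrors the diff).
    chat_id = payload.pop("chat_id", None)
    message_id = payload.pop("id", None)
    modified_body_bytes = json.dumps(payload).encode("utf-8")
    print(chat_id, message_id, len(modified_body_bytes))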