Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
a2e1ea10
Commit
a2e1ea10
authored
Jun 18, 2024
by
Timothy J. Baek
Browse files
feat: tools file handler support
parent
d6ab954f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
19 deletions
+34
-19
backend/main.py
backend/main.py
+34
-19
No files found.
backend/main.py
View file @
a2e1ea10
...
@@ -241,6 +241,12 @@ async def get_function_call_response(
...
@@ -241,6 +241,12 @@ async def get_function_call_response(
toolkit_module
=
load_toolkit_module_by_id
(
tool_id
)
toolkit_module
=
load_toolkit_module_by_id
(
tool_id
)
webui_app
.
state
.
TOOLS
[
tool_id
]
=
toolkit_module
webui_app
.
state
.
TOOLS
[
tool_id
]
=
toolkit_module
file_handler
=
False
# check if toolkit_module has file_handler self variable
if
hasattr
(
toolkit_module
,
"file_handler"
):
file_handler
=
True
print
(
"file_handler: "
,
file_handler
)
function
=
getattr
(
toolkit_module
,
result
[
"name"
])
function
=
getattr
(
toolkit_module
,
result
[
"name"
])
function_result
=
None
function_result
=
None
try
:
try
:
...
@@ -279,12 +285,12 @@ async def get_function_call_response(
...
@@ -279,12 +285,12 @@ async def get_function_call_response(
print
(
e
)
print
(
e
)
# Add the function result to the system prompt
# Add the function result to the system prompt
if
function_result
:
if
function_result
is
not
None
:
return
function_result
return
function_result
,
file_handler
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
print
(
f
"Error:
{
e
}
"
)
return
None
return
None
,
False
class
ChatCompletionMiddleware
(
BaseHTTPMiddleware
):
class
ChatCompletionMiddleware
(
BaseHTTPMiddleware
):
...
@@ -340,12 +346,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -340,12 +346,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context
=
""
context
=
""
# If tool_ids field is present, call the functions
# If tool_ids field is present, call the functions
skip_files
=
False
if
"tool_ids"
in
data
:
if
"tool_ids"
in
data
:
print
(
data
[
"tool_ids"
])
print
(
data
[
"tool_ids"
])
for
tool_id
in
data
[
"tool_ids"
]:
for
tool_id
in
data
[
"tool_ids"
]:
print
(
tool_id
)
print
(
tool_id
)
try
:
try
:
response
=
await
get_function_call_response
(
response
,
file_handler
=
await
get_function_call_response
(
messages
=
data
[
"messages"
],
messages
=
data
[
"messages"
],
files
=
data
.
get
(
"files"
,
[]),
files
=
data
.
get
(
"files"
,
[]),
tool_id
=
tool_id
,
tool_id
=
tool_id
,
...
@@ -354,34 +362,41 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -354,34 +362,41 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user
=
user
,
user
=
user
,
)
)
print
(
file_handler
)
if
isinstance
(
response
,
str
):
if
isinstance
(
response
,
str
):
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
response
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
response
if
file_handler
:
skip_files
=
True
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
print
(
f
"Error:
{
e
}
"
)
del
data
[
"tool_ids"
]
del
data
[
"tool_ids"
]
print
(
f
"tool_context:
{
context
}
"
)
print
(
f
"tool_context:
{
context
}
"
)
# TODO: Check if tools & functions have files support to skip this step to delegate file processing
# If files field is present, generate RAG completions
# If files field is present, generate RAG completions
# If skip_files is True, skip the RAG completions
if
"files"
in
data
:
if
"files"
in
data
:
data
=
{
**
data
}
if
not
skip_files
:
rag_context
,
citations
=
get_rag_context
(
data
=
{
**
data
}
files
=
data
[
"files"
],
rag_context
,
citations
=
get_rag_context
(
messages
=
data
[
"messages"
],
files
=
data
[
"files"
],
embedding_function
=
rag_app
.
state
.
EMBEDDING_FUNCTION
,
messages
=
data
[
"messages"
],
k
=
rag_app
.
state
.
config
.
TOP_K
,
embedding_function
=
rag_app
.
state
.
EMBEDDING_FUNCTION
,
reranking_function
=
rag_app
.
state
.
sentence_transformer_rf
,
k
=
rag_app
.
state
.
config
.
TOP_K
,
r
=
rag_app
.
state
.
config
.
RELEVANCE_THRESHOLD
,
reranking_function
=
rag_app
.
state
.
sentence_transformer_rf
,
hybrid_search
=
rag_app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
,
r
=
rag_app
.
state
.
config
.
RELEVANCE_THRESHOLD
,
)
hybrid_search
=
rag_app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
,
)
if
rag_context
:
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
rag_context
if
rag_context
:
log
.
debug
(
f
"rag_context:
{
rag_context
}
, citations:
{
citations
}
"
)
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
rag_context
else
:
return_citations
=
False
del
data
[
"files"
]
del
data
[
"files"
]
log
.
debug
(
f
"rag_context:
{
rag_context
}
, citations:
{
citations
}
"
)
if
context
!=
""
:
if
context
!=
""
:
system_prompt
=
rag_template
(
system_prompt
=
rag_template
(
...
@@ -968,7 +983,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
...
@@ -968,7 +983,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try
:
try
:
context
=
await
get_function_call_response
(
context
,
file_handler
=
await
get_function_call_response
(
form_data
[
"messages"
],
form_data
[
"messages"
],
form_data
.
get
(
"files"
,
[]),
form_data
.
get
(
"files"
,
[]),
form_data
[
"tool_id"
],
form_data
[
"tool_id"
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment