Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
6bb2f418
Commit
6bb2f418
authored
Jun 20, 2024
by
Timothy J. Baek
Browse files
feat: tool citation
parent
58ae9136
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
17 deletions
+37
-17
backend/main.py
backend/main.py
+37
-17
No files found.
backend/main.py
View file @
6bb2f418
...
...
@@ -247,6 +247,7 @@ async def get_function_call_response(
result
=
json
.
loads
(
content
)
print
(
result
)
citation
=
None
# Call the function
if
"name"
in
result
:
if
tool_id
in
webui_app
.
state
.
TOOLS
:
...
...
@@ -309,22 +310,32 @@ async def get_function_call_response(
}
function_result
=
function
(
**
params
)
if
hasattr
(
toolkit_module
,
"citation"
)
and
toolkit_module
.
citation
:
citation
=
{
"source"
:
{
"name"
:
f
"TOOL:
{
tool
.
name
}
/
{
result
[
'name'
]
}
"
},
"document"
:
[
function_result
],
"metadata"
:
[{
"source"
:
result
[
"name"
]}],
}
except
Exception
as
e
:
print
(
e
)
# Add the function result to the system prompt
if
function_result
is
not
None
:
return
function_result
,
file_handler
return
function_result
,
citation
,
file_handler
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
"
)
return
None
,
False
return
None
,
None
,
False
class
ChatCompletionMiddleware
(
BaseHTTPMiddleware
):
async
def
dispatch
(
self
,
request
:
Request
,
call_next
):
data_items
=
[]
show_citations
=
False
citations
=
[]
if
request
.
method
==
"POST"
and
any
(
endpoint
in
request
.
url
.
path
for
endpoint
in
[
"/ollama/api/chat"
,
"/chat/completions"
]
...
...
@@ -342,6 +353,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files
=
False
if
data
.
get
(
"citations"
):
show_citations
=
True
del
data
[
"citations"
]
model_id
=
data
[
"model"
]
if
model_id
not
in
app
.
state
.
MODELS
:
...
...
@@ -365,8 +379,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
webui_app
.
state
.
FUNCTIONS
[
filter_id
]
=
function_module
# Check if the function has a file_handler variable
if
get
attr
(
function_module
,
"file_handler"
):
skip_files
=
True
if
has
attr
(
function_module
,
"file_handler"
):
skip_files
=
function_module
.
file_handler
try
:
if
hasattr
(
function_module
,
"inlet"
):
...
...
@@ -411,19 +425,25 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
for
tool_id
in
data
[
"tool_ids"
]:
print
(
tool_id
)
try
:
response
,
file_handler
=
await
get_function_call_response
(
messages
=
data
[
"messages"
],
files
=
data
.
get
(
"files"
,
[]),
tool_id
=
tool_id
,
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
,
task_model_id
=
task_model_id
,
user
=
user
,
response
,
citation
,
file_handler
=
(
await
get_function_call_response
(
messages
=
data
[
"messages"
],
files
=
data
.
get
(
"files"
,
[]),
tool_id
=
tool_id
,
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
,
task_model_id
=
task_model_id
,
user
=
user
,
)
)
print
(
file_handler
)
if
isinstance
(
response
,
str
):
context
+=
(
"
\n
"
if
context
!=
""
else
""
)
+
response
if
citation
:
citations
.
append
(
citation
)
show_citations
=
True
if
file_handler
:
skip_files
=
True
...
...
@@ -438,7 +458,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if
"files"
in
data
:
if
not
skip_files
:
data
=
{
**
data
}
rag_context
,
citations
=
get_rag_context
(
rag_context
,
rag_
citations
=
get_rag_context
(
files
=
data
[
"files"
],
messages
=
data
[
"messages"
],
embedding_function
=
rag_app
.
state
.
EMBEDDING_FUNCTION
,
...
...
@@ -452,13 +472,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
log
.
debug
(
f
"rag_context:
{
rag_context
}
, citations:
{
citations
}
"
)
if
citations
and
data
.
get
(
"
citations
"
)
:
data_items
.
append
({
"citations"
:
citations
}
)
if
rag_
citations
:
citations
.
extend
(
rag_
citations
)
del
data
[
"files"
]
if
data
.
get
(
"
citations
"
):
d
el
data
[
"citations"
]
if
show_citations
and
len
(
citations
)
>
0
:
d
ata_items
.
append
({
"citations"
:
citations
})
if
context
!=
""
:
system_prompt
=
rag_template
(
...
...
@@ -1285,7 +1305,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template
=
app
.
state
.
config
.
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try
:
context
,
file_handler
=
await
get_function_call_response
(
context
,
citation
,
file_handler
=
await
get_function_call_response
(
form_data
[
"messages"
],
form_data
.
get
(
"files"
,
[]),
form_data
[
"tool_id"
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment