Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7439a8b5
Unverified
Commit 7439a8b5 authored Dec 11, 2024 by Clayton; committed by GitHub Dec 12, 2024
Browse files
[Bugfix] Multiple fixes to tool streaming with hermes and mistral (#10979)
Signed-off-by: cedonley <clayton@donley.io>
parent
4e116833
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing 3 changed files with 69 additions and 21 deletions
+69
-21
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+14
-2
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+40
-11
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+15
-8
No files found.
vllm/entrypoints/openai/serving_chat.py
View file @
7439a8b5
...
...
@@ -496,21 +496,33 @@ class OpenAIServingChat(OpenAIServing):
if
self
.
_should_check_for_unstreamed_tool_arg_tokens
(
delta_message
,
output
)
and
tool_parser
:
latest_delta_len
=
0
if
((
isinstance
(
delta_message
.
tool_calls
[
0
].
function
,
DeltaFunctionCall
))
and
isinstance
(
delta_message
.
tool_calls
[
0
].
function
.
arguments
,
str
)):
latest_delta_len
=
len
(
delta_message
.
tool_calls
[
0
].
function
.
arguments
)
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call
=
json
.
dumps
(
tool_parser
.
prev_tool_call_arr
[
index
].
get
(
"arguments"
,
{}))
"arguments"
,
{}),
ensure_ascii
=
False
)
# get what we've streamed so far for arguments
# for the current tool
actual_call
=
tool_parser
.
streamed_args_for_tool
[
index
]
if
(
latest_delta_len
>
0
):
actual_call
=
actual_call
[:
-
latest_delta_len
]
# check to see if there's anything left to stream
remaining_call
=
expected_call
.
replace
(
actual_call
,
""
,
1
)
# set that as a delta message
delta_message
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
index
,
...
...
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
View file @
7439a8b5
...
...
@@ -91,7 +91,8 @@ class Hermes2ProToolParser(ToolParser):
function
=
FunctionCall
(
name
=
function_call
[
"name"
],
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
function_call
[
"arguments"
])))
arguments
=
json
.
dumps
(
function_call
[
"arguments"
],
ensure_ascii
=
False
)))
for
function_call
in
raw_function_calls
]
...
...
@@ -139,13 +140,26 @@ class Hermes2ProToolParser(ToolParser):
self
.
tool_call_start_token_id
)
cur_tool_end_count
=
current_token_ids
.
count
(
self
.
tool_call_end_token_id
)
tool_call_portion
=
None
text_portion
=
None
# case: if we're generating text, OR rounding out a tool call
if
(
cur_tool_start_count
==
cur_tool_end_count
and
prev_tool_end_count
==
cur_tool_end_count
):
and
prev_tool_end_count
==
cur_tool_end_count
and
self
.
tool_call_end_token
not
in
delta_text
):
logger
.
debug
(
"Generating text content! skipping tool parsing."
)
if
delta_text
!=
self
.
tool_call_end_token
:
return
DeltaMessage
(
content
=
delta_text
)
return
DeltaMessage
(
content
=
delta_text
)
if
self
.
tool_call_end_token
in
delta_text
:
logger
.
debug
(
"tool_call_end_token in delta_text"
)
full_text
=
current_text
+
delta_text
tool_call_portion
=
full_text
.
split
(
self
.
tool_call_start_token
)[
-
1
].
split
(
self
.
tool_call_end_token
)[
0
].
rstrip
()
delta_text
=
delta_text
.
split
(
self
.
tool_call_end_token
)[
0
].
rstrip
()
text_portion
=
delta_text
.
split
(
self
.
tool_call_end_token
)[
-
1
].
lstrip
()
# case: if tool open & close tag counts don't match, we're doing
# imaginary "else" block here
...
...
@@ -184,15 +198,21 @@ class Hermes2ProToolParser(ToolParser):
# case -- the current tool call is being closed.
elif
(
cur_tool_start_count
==
cur_tool_end_count
and
cur_tool_end_count
>
prev_tool_end_count
):
and
cur_tool_end_count
>=
prev_tool_end_count
):
if
(
self
.
prev_tool_call_arr
is
None
or
len
(
self
.
prev_tool_call_arr
)
==
0
):
logger
.
debug
(
"attempting to close tool call, but no tool call"
)
return
None
diff
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
if
diff
:
diff
=
diff
.
encode
(
'utf-8'
).
decode
(
'unicode_escape'
)
if
diff
is
str
else
diff
diff
=
json
.
dumps
(
diff
,
ensure_ascii
=
False
)[
len
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]):]
if
(
'"}'
not
in
delta_text
):
return
None
end_loc
=
delta_text
.
rindex
(
'"}'
)
diff
=
delta_text
[:
end_loc
]
+
'"}'
logger
.
debug
(
"Finishing tool and found diff that had not "
"been streamed yet: %s"
,
diff
)
...
...
@@ -221,10 +241,15 @@ class Hermes2ProToolParser(ToolParser):
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
return
None
except
json
.
decoder
.
JSONDecodeError
:
logger
.
debug
(
"unable to parse JSON"
)
return
None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if
not
self
.
current_tool_name_sent
:
if
(
current_tool_call
is
None
):
return
None
function_name
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"name"
)
if
function_name
:
self
.
current_tool_name_sent
=
True
...
...
@@ -284,13 +309,17 @@ class Hermes2ProToolParser(ToolParser):
# autocompleting the JSON
elif
cur_arguments
and
not
prev_arguments
:
cur_arguments_json
=
json
.
dumps
(
cur_arguments
)
cur_arguments_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
logger
.
debug
(
"finding %s in %s"
,
delta_text
,
cur_arguments_json
)
# get the location where previous args differ from current
args_delta_start_loc
=
cur_arguments_json
.
index
(
delta_text
)
\
+
len
(
delta_text
)
if
(
delta_text
not
in
cur_arguments_json
[:
-
2
]):
return
None
args_delta_start_loc
=
cur_arguments_json
[:
-
2
].
\
rindex
(
delta_text
)
+
\
len
(
delta_text
)
# use that to find the actual delta
arguments_delta
=
cur_arguments_json
[:
args_delta_start_loc
]
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
7439a8b5
...
...
@@ -19,7 +19,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -109,7 +108,8 @@ class MistralToolParser(ToolParser):
function
=
FunctionCall
(
name
=
raw_function_call
[
"name"
],
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
raw_function_call
[
"arguments"
])))
arguments
=
json
.
dumps
(
raw_function_call
[
"arguments"
],
ensure_ascii
=
False
)))
for
raw_function_call
in
function_call_arr
]
...
...
@@ -199,7 +199,7 @@ class MistralToolParser(ToolParser):
diff
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"arguments"
)
if
diff
:
diff
=
json
.
dumps
(
diff
).
replace
(
diff
=
json
.
dumps
(
diff
,
ensure_ascii
=
False
).
replace
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
],
""
)
delta
=
DeltaMessage
(
tool_calls
=
[
...
...
@@ -232,7 +232,7 @@ class MistralToolParser(ToolParser):
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
id
=
f
"chatcmpl-tool-
{
random_
uu
id
()
}
"
,
id
=
MistralToolCall
.
generate_
random_id
(),
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
))
...
...
@@ -250,6 +250,8 @@ class MistralToolParser(ToolParser):
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
new_text
=
delta_text
.
replace
(
"
\'
"
,
"
\"
"
)
if
(
'"}'
in
new_text
):
new_text
=
new_text
[:
new_text
.
rindex
(
'"}'
)]
if
not
cur_arguments
and
not
prev_arguments
:
...
...
@@ -260,12 +262,15 @@ class MistralToolParser(ToolParser):
"mid-arguments"
)
delta
=
None
elif
cur_arguments
and
not
prev_arguments
:
cur_arguments_json
=
json
.
dumps
(
cur_arguments
)
cur_arguments_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)[:
-
2
]
logger
.
debug
(
"finding %s in %s"
,
new_text
,
cur_arguments_json
)
if
(
new_text
not
in
cur_arguments_json
):
return
None
arguments_delta
=
cur_arguments_json
[:
cur_arguments_json
.
index
(
new_text
)
+
r
index
(
new_text
)
+
len
(
new_text
)]
logger
.
debug
(
"First tokens in arguments received: %s"
,
arguments_delta
)
...
...
@@ -279,8 +284,10 @@ class MistralToolParser(ToolParser):
self
.
current_tool_id
]
+=
arguments_delta
elif
cur_arguments
and
prev_arguments
:
cur_args_json
=
json
.
dumps
(
cur_arguments
)
prev_args_json
=
json
.
dumps
(
prev_arguments
)
cur_args_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
prev_args_json
=
json
.
dumps
(
prev_arguments
,
ensure_ascii
=
False
)
logger
.
debug
(
"Searching for diff between
\n
%s
\n
%s"
,
cur_args_json
,
prev_args_json
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment