Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aee4c146
Unverified
Commit
aee4c146
authored
Mar 27, 2026
by
Flora Feng
Committed by
GitHub
Mar 27, 2026
Browse files
[Bugfix] Fix Hermes tool parser when stream interval > 1 (#38168)
Signed-off-by:
sfeng33
<
4florafeng@gmail.com
>
parent
0ae89f18
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
349 additions
and
382 deletions
+349
-382
tests/tool_parsers/test_hermes_tool_parser.py
tests/tool_parsers/test_hermes_tool_parser.py
+194
-0
vllm/tool_parsers/hermes_tool_parser.py
vllm/tool_parsers/hermes_tool_parser.py
+153
-364
vllm/tool_parsers/longcat_tool_parser.py
vllm/tool_parsers/longcat_tool_parser.py
+2
-18
No files found.
tests/tool_parsers/test_hermes_tool_parser.py
View file @
aee4c146
...
@@ -152,6 +152,175 @@ def test_hermes_parser_streaming(
...
@@ -152,6 +152,175 @@ def test_hermes_parser_streaming(
}
}
def
_simulate_streaming
(
tokenizer
:
TokenizerLike
,
parser
:
ToolParser
,
request
:
ChatCompletionRequest
,
text
:
str
,
stream_interval
:
int
=
1
,
)
->
list
:
"""Simulate streaming with a given stream_interval.
Tokens are batched into chunks of `stream_interval` tokens,
mimicking how the output processor delivers them.
Returns a list of non-None DeltaMessages.
"""
tokens
=
tokenizer
.
encode
(
text
)
previous_text
=
""
delta_messages
=
[]
for
i
in
range
(
0
,
len
(
tokens
),
stream_interval
):
chunk_ids
=
tokens
[
i
:
i
+
stream_interval
]
delta_text
=
tokenizer
.
decode
(
chunk_ids
)
current_text
=
previous_text
+
delta_text
delta
=
parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
,
current_text
=
current_text
,
delta_text
=
delta_text
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
chunk_ids
,
request
=
request
,
)
previous_text
=
current_text
if
delta
is
not
None
:
delta_messages
.
append
(
delta
)
return
delta_messages
@
pytest
.
mark
.
parametrize
(
"stream_interval"
,
[
2
,
3
,
5
,
8
])
def
test_hermes_streaming_tool_call_with_stream_interval
(
qwen_tokenizer
:
TokenizerLike
,
any_chat_request
:
ChatCompletionRequest
,
stream_interval
:
int
,
)
->
None
:
"""Tool call streaming must produce correct name + args at any interval."""
text
=
(
'<tool_call>{"name": "get_current_temperature", '
'"arguments": {"location": "San Francisco", "unit": "celsius"}}'
"</tool_call>"
)
parser
=
Hermes2ProToolParser
(
qwen_tokenizer
)
deltas
=
_simulate_streaming
(
qwen_tokenizer
,
parser
,
any_chat_request
,
text
,
stream_interval
)
# Flatten all DeltaToolCalls across all deltas.
tool_deltas
=
[
tc
for
d
in
deltas
if
d
.
tool_calls
for
tc
in
d
.
tool_calls
]
assert
tool_deltas
,
"Expected at least one tool call delta"
assert
tool_deltas
[
0
].
function
.
name
==
"get_current_temperature"
# Concatenated arguments must be valid JSON matching the original.
args_str
=
""
.
join
(
tc
.
function
.
arguments
or
""
for
tc
in
tool_deltas
)
assert
json
.
loads
(
args_str
)
==
{
"location"
:
"San Francisco"
,
"unit"
:
"celsius"
,
}
@
pytest
.
mark
.
parametrize
(
"stream_interval"
,
[
2
,
3
,
5
,
8
])
def
test_hermes_streaming_content_then_tool_call_with_stream_interval
(
qwen_tokenizer
:
TokenizerLike
,
any_chat_request
:
ChatCompletionRequest
,
stream_interval
:
int
,
)
->
None
:
"""Content before a tool call must be fully streamed, then tool call."""
text
=
(
"Sure, let me check the weather."
'<tool_call>{"name": "get_weather", '
'"arguments": {"city": "NYC"}}</tool_call>'
)
parser
=
Hermes2ProToolParser
(
qwen_tokenizer
)
deltas
=
_simulate_streaming
(
qwen_tokenizer
,
parser
,
any_chat_request
,
text
,
stream_interval
)
content_deltas
=
[
d
for
d
in
deltas
if
d
.
content
]
tool_deltas
=
[
d
for
d
in
deltas
if
d
.
tool_calls
]
# Content must reconstruct the prefix.
content_str
=
""
.
join
(
d
.
content
for
d
in
content_deltas
)
assert
content_str
==
"Sure, let me check the weather."
# Tool call must be correct.
tool_calls
=
[
tc
for
d
in
tool_deltas
for
tc
in
d
.
tool_calls
]
assert
tool_calls
[
0
].
function
.
name
==
"get_weather"
args_str
=
""
.
join
(
tc
.
function
.
arguments
or
""
for
tc
in
tool_calls
)
assert
json
.
loads
(
args_str
)
==
{
"city"
:
"NYC"
}
@
pytest
.
mark
.
parametrize
(
"stream_interval"
,
[
1
,
2
,
4
])
def
test_hermes_streaming_multiple_tool_calls_with_stream_interval
(
qwen_tokenizer
:
TokenizerLike
,
any_chat_request
:
ChatCompletionRequest
,
stream_interval
:
int
,
)
->
None
:
"""Multiple sequential tool calls must each be streamed correctly."""
text
=
(
'<tool_call>{"name": "search", "arguments": {"q": "cats"}}</tool_call>'
'<tool_call>{"name": "search", "arguments": {"q": "dogs"}}</tool_call>'
)
parser
=
Hermes2ProToolParser
(
qwen_tokenizer
)
deltas
=
_simulate_streaming
(
qwen_tokenizer
,
parser
,
any_chat_request
,
text
,
stream_interval
)
# Flatten all DeltaToolCalls across all deltas.
all_tool_calls
=
[
tc
for
d
in
deltas
if
d
.
tool_calls
for
tc
in
d
.
tool_calls
]
# Separate by tool index.
tool0
=
[
tc
for
tc
in
all_tool_calls
if
tc
.
index
==
0
]
tool1
=
[
tc
for
tc
in
all_tool_calls
if
tc
.
index
==
1
]
assert
tool0
[
0
].
function
.
name
==
"search"
args0
=
""
.
join
(
tc
.
function
.
arguments
or
""
for
tc
in
tool0
)
assert
json
.
loads
(
args0
)
==
{
"q"
:
"cats"
}
assert
tool1
[
0
].
function
.
name
==
"search"
args1
=
""
.
join
(
tc
.
function
.
arguments
or
""
for
tc
in
tool1
)
assert
json
.
loads
(
args1
)
==
{
"q"
:
"dogs"
}
@
pytest
.
mark
.
parametrize
(
"stream_interval"
,
[
2
,
5
])
def
test_hermes_streaming_boolean_args_with_stream_interval
(
qwen_tokenizer
:
TokenizerLike
,
any_chat_request
:
ChatCompletionRequest
,
stream_interval
:
int
,
)
->
None
:
"""Regression test for bug #19056 with stream_interval > 1."""
text
=
(
"<tool_call>
\n
"
'{"name": "final_answer", "arguments": {"trigger": true}}
\n
'
"</tool_call>"
)
parser
=
Hermes2ProToolParser
(
qwen_tokenizer
)
deltas
=
_simulate_streaming
(
qwen_tokenizer
,
parser
,
any_chat_request
,
text
,
stream_interval
)
tool_calls
=
[
tc
for
d
in
deltas
if
d
.
tool_calls
for
tc
in
d
.
tool_calls
]
assert
tool_calls
[
0
].
function
.
name
==
"final_answer"
args_str
=
""
.
join
(
tc
.
function
.
arguments
or
""
for
tc
in
tool_calls
)
assert
json
.
loads
(
args_str
)
==
{
"trigger"
:
True
}
@
pytest
.
mark
.
parametrize
(
"stream_interval"
,
[
2
,
3
,
5
])
def
test_hermes_streaming_just_forward_text_with_stream_interval
(
qwen_tokenizer
:
TokenizerLike
,
any_chat_request
:
ChatCompletionRequest
,
stream_interval
:
int
,
)
->
None
:
"""Plain text with no tool calls must be fully forwarded."""
text
=
"This is plain text with no tool calling involved."
parser
=
Hermes2ProToolParser
(
qwen_tokenizer
)
deltas
=
_simulate_streaming
(
qwen_tokenizer
,
parser
,
any_chat_request
,
text
,
stream_interval
)
for
d
in
deltas
:
assert
not
d
.
tool_calls
assert
""
.
join
(
d
.
content
for
d
in
deltas
)
==
text
def
test_hermes_parser_non_streaming_no_tool_call
(
def
test_hermes_parser_non_streaming_no_tool_call
(
hermes_parser
:
ToolParser
,
hermes_parser
:
ToolParser
,
any_chat_request
:
ChatCompletionRequest
,
any_chat_request
:
ChatCompletionRequest
,
...
@@ -218,3 +387,28 @@ def test_hermes_parser_non_streaming_tool_call_invalid_json(
...
@@ -218,3 +387,28 @@ def test_hermes_parser_non_streaming_tool_call_invalid_json(
assert
tool_call
is
not
None
assert
tool_call
is
not
None
assert
not
tool_call
.
tools_called
assert
not
tool_call
.
tools_called
def
test_hermes_streaming_content_and_tool_call_in_single_chunk
(
qwen_tokenizer
:
TokenizerLike
,
any_chat_request
:
ChatCompletionRequest
,
)
->
None
:
"""Content + complete tool call in one chunk must both be emitted."""
text
=
'Hi!<tool_call>{"name": "f", "arguments": {"x": 1}}</tool_call>'
# Use a stream_interval large enough to guarantee a single chunk.
parser
=
Hermes2ProToolParser
(
qwen_tokenizer
)
deltas
=
_simulate_streaming
(
qwen_tokenizer
,
parser
,
any_chat_request
,
text
,
stream_interval
=
9999
,
)
content_parts
=
[
d
.
content
for
d
in
deltas
if
d
.
content
]
tool_parts
=
[
tc
for
d
in
deltas
if
d
.
tool_calls
for
tc
in
d
.
tool_calls
]
assert
""
.
join
(
content_parts
)
==
"Hi!"
assert
tool_parts
[
0
].
function
.
name
==
"f"
args_str
=
""
.
join
(
tc
.
function
.
arguments
or
""
for
tc
in
tool_parts
)
assert
json
.
loads
(
args_str
)
==
{
"x"
:
1
}
vllm/tool_parsers/hermes_tool_parser.py
View file @
aee4c146
...
@@ -4,9 +4,7 @@
...
@@ -4,9 +4,7 @@
import
json
import
json
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
import
partial_json_parser
import
regex
as
re
import
regex
as
re
from
partial_json_parser.core.options
import
Allow
from
vllm.entrypoints.chat_utils
import
make_tool_call_id
from
vllm.entrypoints.chat_utils
import
make_tool_call_id
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
...
@@ -31,6 +29,27 @@ from vllm.utils.mistral import is_mistral_tokenizer
...
@@ -31,6 +29,27 @@ from vllm.utils.mistral import is_mistral_tokenizer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
_partial_tag_overlap
(
text
:
str
,
tag
:
str
)
->
int
:
"""Length of the longest prefix of `tag` that matches a suffix of `text`.
E.g. text ending in "<tool_" returns 6 when tag is "<tool_call>".
Returns 0 if there is no overlap.
"""
max_check
=
min
(
len
(
tag
)
-
1
,
len
(
text
))
for
k
in
range
(
max_check
,
0
,
-
1
):
if
text
.
endswith
(
tag
[:
k
]):
return
k
return
0
def
_is_valid_json
(
text
:
str
)
->
bool
:
try
:
json
.
loads
(
text
)
return
True
except
(
json
.
JSONDecodeError
,
ValueError
):
return
False
class
Hermes2ProToolParser
(
ToolParser
):
class
Hermes2ProToolParser
(
ToolParser
):
def
__init__
(
self
,
tokenizer
:
TokenizerLike
,
tools
:
list
[
Tool
]
|
None
=
None
):
def
__init__
(
self
,
tokenizer
:
TokenizerLike
,
tools
:
list
[
Tool
]
|
None
=
None
):
super
().
__init__
(
tokenizer
,
tools
)
super
().
__init__
(
tokenizer
,
tools
)
...
@@ -39,13 +58,6 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -39,13 +58,6 @@ class Hermes2ProToolParser(ToolParser):
logger
.
error
(
"Detected Mistral tokenizer when using a Hermes model"
)
logger
.
error
(
"Detected Mistral tokenizer when using a Hermes model"
)
self
.
model_tokenizer
=
tokenizer
.
tokenizer
self
.
model_tokenizer
=
tokenizer
.
tokenizer
self
.
current_tool_name_sent
:
bool
=
False
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
streamed_args_for_tool
:
list
[
str
]
=
[]
# map what has been streamed for each tool so far to a list
self
.
tool_call_start_token
:
str
=
"<tool_call>"
self
.
tool_call_start_token
:
str
=
"<tool_call>"
self
.
tool_call_end_token
:
str
=
"</tool_call>"
self
.
tool_call_end_token
:
str
=
"</tool_call>"
...
@@ -61,57 +73,9 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -61,57 +73,9 @@ class Hermes2ProToolParser(ToolParser):
"The model tokenizer must be passed to the ToolParser "
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
"constructor during construction."
)
)
self
.
tool_call_start_token_ids
=
self
.
model_tokenizer
.
encode
(
self
.
tool_call_start_token
,
add_special_tokens
=
False
)
self
.
tool_call_end_token_ids
=
self
.
model_tokenizer
.
encode
(
self
.
tool_call_end_token
,
add_special_tokens
=
False
)
self
.
tool_call_start_token_array
=
[
# Streaming state: what has been sent to the client.
self
.
model_tokenizer
.
decode
([
token_id
])
self
.
_sent_content_idx
:
int
=
0
for
token_id
in
self
.
tool_call_start_token_ids
]
self
.
tool_call_end_token_array
=
[
self
.
model_tokenizer
.
decode
([
token_id
])
for
token_id
in
self
.
tool_call_end_token_ids
]
self
.
buffered_delta_text
=
""
# Very simple idea: when encountering tokens like <, tool, _call, >,
# <, /, tool, _call, >, store them in a buffer.
# When the last token is encountered, empty the buffer and return it.
# If a token appears in an incorrect sequence while storing in the buffer,
# return the preceding buffer along with the token.
def
tool_call_delta_buffer
(
self
,
delta_text
:
str
):
# If the sequence of tool_call_start or tool_call_end tokens is not yet
# complete, fill the buffer with the token and return "".
if
(
delta_text
in
self
.
tool_call_start_token_array
or
delta_text
in
self
.
tool_call_end_token_array
):
# If delta_text is the last token of tool_call_start_token or
# tool_call_end_token, empty the buffer and return
# the buffered text + delta_text.
if
(
delta_text
==
self
.
tool_call_start_token_array
[
-
1
]
or
delta_text
==
self
.
tool_call_end_token_array
[
-
1
]
):
buffered_text
=
self
.
buffered_delta_text
self
.
buffered_delta_text
=
""
return
buffered_text
+
delta_text
else
:
self
.
buffered_delta_text
=
self
.
buffered_delta_text
+
delta_text
return
""
else
:
if
self
.
buffered_delta_text
:
buffered_text
=
self
.
buffered_delta_text
self
.
buffered_delta_text
=
""
return
buffered_text
+
delta_text
else
:
return
delta_text
def
adjust_request
(
self
,
request
:
ChatCompletionRequest
)
->
ChatCompletionRequest
:
def
adjust_request
(
self
,
request
:
ChatCompletionRequest
)
->
ChatCompletionRequest
:
request
=
super
().
adjust_request
(
request
)
request
=
super
().
adjust_request
(
request
)
...
@@ -174,6 +138,88 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -174,6 +138,88 @@ class Hermes2ProToolParser(ToolParser):
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
)
def
_extract_content
(
self
,
current_text
:
str
)
->
str
|
None
:
"""Return unsent non-tool-call text, or None.
Holds back any suffix that could be a partial <tool_call> tag.
"""
if
self
.
tool_call_start_token
not
in
current_text
:
overlap_length
=
_partial_tag_overlap
(
current_text
,
self
.
tool_call_start_token
)
sendable_idx
=
len
(
current_text
)
-
overlap_length
else
:
sendable_idx
=
current_text
.
index
(
self
.
tool_call_start_token
)
if
sendable_idx
>
self
.
_sent_content_idx
:
content
=
current_text
[
self
.
_sent_content_idx
:
sendable_idx
]
self
.
_sent_content_idx
=
sendable_idx
return
content
return
None
def
_extract_tool_call_jsons
(
self
,
text
:
str
)
->
list
[
tuple
[
str
,
bool
]]:
"""Extract (json_text, is_complete) for each <tool_call> region."""
results
:
list
[
tuple
[
str
,
bool
]]
=
[]
pos
=
0
while
True
:
start
=
text
.
find
(
self
.
tool_call_start_token
,
pos
)
if
start
==
-
1
:
break
json_start
=
start
+
len
(
self
.
tool_call_start_token
)
json_end
=
text
.
find
(
self
.
tool_call_end_token
,
json_start
)
if
json_end
!=
-
1
:
results
.
append
((
text
[
json_start
:
json_end
].
strip
(),
True
))
pos
=
json_end
+
len
(
self
.
tool_call_end_token
)
else
:
raw
=
text
[
json_start
:]
# Strip partial </tool_call> suffix if present.
overlap
=
_partial_tag_overlap
(
raw
,
self
.
tool_call_end_token
)
if
overlap
:
raw
=
raw
[:
-
overlap
]
tc_json
=
raw
.
strip
()
# Valid JSON without closing tag = complete body,
# tag tokens just haven't arrived yet.
is_complete
=
_is_valid_json
(
tc_json
)
if
tc_json
else
False
results
.
append
((
tc_json
,
is_complete
))
break
return
results
@
staticmethod
def
_extract_tool_name
(
tc_json
:
str
)
->
str
|
None
:
"""Extract tool name, or None if the name isn't complete yet."""
match
=
re
.
search
(
r
'"name"\s*:\s*"([^"]+)"'
,
tc_json
)
return
match
.
group
(
1
)
if
match
else
None
@
staticmethod
def
_extract_tool_args
(
tc_json
:
str
,
is_complete
:
bool
)
->
str
|
None
:
"""Extract tool arguments from the tool call JSON.
Given {"name": "f", "arguments": {"x": 1}}, returns '{"x": 1}'.
When is_complete, strips the trailing '}' that closes the outer
object (not the arguments). For partial JSON, returns as-is.
"""
match
=
re
.
search
(
r
'"arguments"\s*:\s*'
,
tc_json
)
if
not
match
:
return
None
raw
=
tc_json
[
match
.
end
()
:]
if
is_complete
:
raw
=
raw
.
rstrip
()
if
raw
.
endswith
(
"}"
):
raw
=
raw
[:
-
1
].
rstrip
()
return
raw
def
_compute_args_diff
(
self
,
index
:
int
,
tc_json
:
str
,
is_complete
:
bool
)
->
str
|
None
:
"""Return new argument text not yet sent for tool `index`, or None."""
args
=
self
.
_extract_tool_args
(
tc_json
,
is_complete
)
if
args
is
None
or
len
(
args
)
<=
len
(
self
.
streamed_args_for_tool
[
index
]):
return
None
diff
=
args
[
len
(
self
.
streamed_args_for_tool
[
index
])
:]
self
.
streamed_args_for_tool
[
index
]
=
args
self
.
prev_tool_call_arr
[
index
][
"arguments"
]
=
args
return
diff
def
extract_tool_calls_streaming
(
def
extract_tool_calls_streaming
(
self
,
self
,
previous_text
:
str
,
previous_text
:
str
,
...
@@ -184,321 +230,64 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -184,321 +230,64 @@ class Hermes2ProToolParser(ToolParser):
delta_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
)
->
DeltaMessage
|
None
:
# 1. All tokens are parsed based on _text, not token_ids.
"""Incrementally stream tool call deltas from accumulated output.
# 2. All incoming text data is processed by the tool_call_delta_buffer
# function for buffering before being used for parsing.
delta_text
=
self
.
tool_call_delta_buffer
(
delta_text
)
# If the last characters of previous_text
# match self.buffered_delta_text, remove only the matching part.
if
(
len
(
previous_text
)
>=
len
(
self
.
buffered_delta_text
)
and
previous_text
[
-
len
(
self
.
buffered_delta_text
)
:]
==
self
.
buffered_delta_text
):
previous_text
=
previous_text
[:
-
len
(
self
.
buffered_delta_text
)]
current_text
=
previous_text
+
delta_text
logger
.
debug
(
"delta_text: %s"
,
delta_text
)
logger
.
debug
(
"delta_token_ids: %s"
,
delta_token_ids
)
# check to see if we should be streaming a tool call - is there a
if
self
.
tool_call_start_token
not
in
current_text
:
logger
.
debug
(
"No tool call tokens found!"
)
return
DeltaMessage
(
content
=
delta_text
)
try
:
On each invocation, re-parses the full ``current_text`` to find
# figure out where we are in the parsing by counting tool call
``<tool_call>`` regions, then diffs against previously sent state
# start & end tags
to emit only new content, tool names, or argument fragments.
prev_tool_start_count
=
previous_text
.
count
(
self
.
tool_call_start_token
)
prev_tool_end_count
=
previous_text
.
count
(
self
.
tool_call_end_token
)
cur_tool_start_count
=
current_text
.
count
(
self
.
tool_call_start_token
)
cur_tool_end_count
=
current_text
.
count
(
self
.
tool_call_end_token
)
tool_call_portion
=
None
text_portion
=
None
# case: if we're generating text, OR rounding out a tool call
if
(
cur_tool_start_count
==
cur_tool_end_count
and
prev_tool_end_count
==
cur_tool_end_count
and
self
.
tool_call_end_token
not
in
delta_text
):
logger
.
debug
(
"Generating text content! skipping tool parsing."
)
return
DeltaMessage
(
content
=
delta_text
)
if
self
.
tool_call_end_token
in
delta_text
:
logger
.
debug
(
"tool_call_end_token in delta_text"
)
full_text
=
current_text
+
delta_text
tool_call_portion
=
(
full_text
.
split
(
self
.
tool_call_start_token
)[
-
1
]
.
split
(
self
.
tool_call_end_token
)[
0
]
.
rstrip
()
)
delta_text
=
delta_text
.
split
(
self
.
tool_call_end_token
)[
0
].
rstrip
()
text_portion
=
delta_text
.
split
(
self
.
tool_call_end_token
)[
-
1
].
lstrip
()
# case: if tool open & close tag counts don't match, we're doing
# imaginary "else" block here
# something with tools with this diff.
# flags for partial JSON parting. exported constants from
# "Allow" are handled via BIT MASK
flags
=
Allow
.
ALL
if
self
.
current_tool_name_sent
else
Allow
.
ALL
&
~
Allow
.
STR
# case -- we're starting a new tool call
if
(
cur_tool_start_count
>
cur_tool_end_count
and
cur_tool_start_count
>
prev_tool_start_count
):
if
len
(
delta_token_ids
)
>
1
:
tool_call_portion
=
current_text
.
split
(
self
.
tool_call_start_token
)[
-
1
]
else
:
tool_call_portion
=
None
delta
=
None
text_portion
=
None
# set cursors and state appropriately
self
.
current_tool_id
+=
1
self
.
current_tool_name_sent
=
False
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"Starting on a new tool %s"
,
self
.
current_tool_id
)
# case -- we're updating an existing tool call
elif
(
cur_tool_start_count
>
cur_tool_end_count
and
cur_tool_start_count
==
prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion
=
current_text
.
split
(
self
.
tool_call_start_token
)[
-
1
]
text_portion
=
None
# case -- the current tool call is being closed.
elif
(
cur_tool_start_count
==
cur_tool_end_count
and
cur_tool_end_count
>=
prev_tool_end_count
):
if
self
.
prev_tool_call_arr
is
None
or
len
(
self
.
prev_tool_call_arr
)
==
0
:
logger
.
debug
(
"attempting to close tool call, but no tool call"
)
return
None
diff
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
if
diff
:
diff
=
(
diff
.
encode
(
"utf-8"
).
decode
(
"unicode_escape"
)
if
diff
is
str
else
diff
)
if
'"}'
not
in
delta_text
:
return
None
end_loc
=
delta_text
.
rindex
(
'"}'
)
diff
=
delta_text
[:
end_loc
]
+
'"}'
logger
.
debug
(
"Finishing tool and found diff that had not "
"been streamed yet: %s"
,
diff
,
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
diff
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
diff
).
model_dump
(
exclude_none
=
True
),
)
]
)
# case -- otherwise we're just generating text
Returns a ``DeltaMessage`` containing either plain content (for
else
:
text preceding any tool call) or one or more ``DeltaToolCall``
text
=
delta_text
.
replace
(
self
.
tool_call_start_token
,
""
)
entries, or ``None`` if there is nothing new to send yet."""
text
=
text
.
replace
(
self
.
tool_call_end_token
,
""
)
try
:
delta
=
DeltaMessage
(
tool_calls
=
[],
content
=
text
)
# Extract any content before tool calls.
return
delta
content
=
self
.
_extract_content
(
current_text
)
tool_call_jsons
=
self
.
_extract_tool_call_jsons
(
current_text
)
try
:
tool_call_deltas
:
list
[
DeltaToolCall
]
=
[]
current_tool_call
=
(
partial_json_parser
.
loads
(
tool_call_portion
or
"{}"
,
flags
)
for
i
,
(
tc_json
,
is_complete
)
in
enumerate
(
tool_call_jsons
):
if
tool_call_portion
if
i
>=
len
(
self
.
prev_tool_call_arr
):
else
None
self
.
prev_tool_call_arr
.
append
({})
)
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"Parsed tool call %s"
,
current_tool_call
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
# Stream back tool name.
logger
.
debug
(
"not enough tokens to parse into JSON yet"
)
if
"name"
not
in
self
.
prev_tool_call_arr
[
i
]:
return
None
name
=
self
.
_extract_tool_name
(
tc_json
)
except
json
.
decoder
.
JSONDecodeError
:
if
not
name
:
logger
.
debug
(
"unable to parse JSON"
)
# Can't skip to tool i+1 if i isn't ready
return
None
break
self
.
prev_tool_call_arr
[
i
][
"name"
]
=
name
if
current_tool_call
is
None
:
tool_call_deltas
.
append
(
return
None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if
not
self
.
current_tool_name_sent
:
function_name
:
str
|
None
=
current_tool_call
.
get
(
"name"
)
if
function_name
:
self
.
current_tool_name_sent
=
True
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
id
=
make_tool_call_id
(),
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
),
)
]
)
else
:
return
None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if
tool_call_portion
is
None
:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta
=
(
DeltaMessage
(
content
=
delta_text
)
if
text_portion
is
not
None
else
None
)
return
delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
if
current_tool_call
is
None
:
return
None
logger
.
debug
(
"Trying to parse current tool call with ID %s"
,
self
.
current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if
len
(
self
.
prev_tool_call_arr
)
<=
self
.
current_tool_id
:
self
.
prev_tool_call_arr
.
append
({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
assert
current_tool_call
is
not
None
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
logger
.
debug
(
"diffing old arguments: %s"
,
prev_arguments
)
logger
.
debug
(
"against new ones: %s"
,
cur_arguments
)
# case -- no arguments have been created yet. skip sending a delta.
if
not
cur_arguments
and
not
prev_arguments
:
logger
.
debug
(
"Skipping text %s - no arguments"
,
delta_text
)
delta
=
None
# case -- prev arguments are defined, but non are now.
# probably impossible, but not a fatal error - just keep going
elif
not
cur_arguments
and
prev_arguments
:
logger
.
error
(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta
=
None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif
cur_arguments
and
not
prev_arguments
:
# extract the content after {"name": ..., "arguments":
# directly from tool_call_portion as cur_arguments_json,
# since cur_arguments may differ from the original text
# due to partial JSON parsing
# for example, tool_call_portion =
# {"name": "search", "arguments": {"search_request": {"
# but cur_arguments =
# {"search_request": {}}
function_name
=
current_tool_call
.
get
(
"name"
)
match
=
re
.
search
(
r
'\{"name":\s*"'
+
re
.
escape
(
function_name
)
+
r
'"\s*,\s*"arguments":\s*(.*)'
,
tool_call_portion
.
strip
(),
re
.
DOTALL
,
)
if
match
:
cur_arguments_json
=
match
.
group
(
1
)
else
:
cur_arguments_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
logger
.
debug
(
"finding %s in %s"
,
delta_text
,
cur_arguments_json
)
# get the location where previous args differ from current.
if
delta_text
not
in
cur_arguments_json
:
return
None
args_delta_start_loc
=
cur_arguments_json
.
rindex
(
delta_text
)
+
len
(
delta_text
)
# use that to find the actual delta
arguments_delta
=
cur_arguments_json
[:
args_delta_start_loc
]
logger
.
debug
(
"First tokens in arguments received: %s"
,
arguments_delta
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
DeltaToolCall
(
index
=
self
.
current_tool_id
,
index
=
i
,
function
=
DeltaFunctionCall
(
type
=
"function"
,
arguments
=
arguments_delta
id
=
make_tool_call_id
(),
).
model_dump
(
exclude_none
=
True
),
function
=
DeltaFunctionCall
(
name
=
name
).
model_dump
(
exclude_none
=
True
),
)
)
]
)
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
arguments_delta
# Stream back new tool args by diffing against what was sent.
args_diff
=
self
.
_compute_args_diff
(
i
,
tc_json
,
is_complete
)
# last case -- we have an update to existing arguments.
if
args_diff
:
elif
cur_arguments
and
prev_arguments
:
tool_call_deltas
.
append
(
# judge whether the tool_call_portion is a complete JSON
try
:
json
.
loads
(
tool_call_portion
)
is_complete_json
=
True
except
Exception
:
is_complete_json
=
False
# if the delta_text ends with a '}' and tool_call_portion is a
# complete JSON, then the last '}' does not belong to the
# arguments, so we should trim it off
if
(
isinstance
(
delta_text
,
str
)
and
len
(
delta_text
.
rstrip
())
>=
1
and
delta_text
.
rstrip
()[
-
1
]
==
"}"
and
is_complete_json
):
delta_text
=
delta_text
.
rstrip
()[:
-
1
]
logger
.
debug
(
"got diff %s"
,
delta_text
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
DeltaToolCall
(
index
=
self
.
current_tool_id
,
index
=
i
,
function
=
DeltaFunctionCall
(
arguments
=
delta_text
).
model_dump
(
function
=
DeltaFunctionCall
(
arguments
=
args_diff
).
model_dump
(
exclude_none
=
True
exclude_none
=
True
),
),
)
)
]
)
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
delta_text
# handle saving the state for the current tool into
if
content
or
tool_call_deltas
:
# the "prev" list for use in diffing for the next iteration
return
DeltaMessage
(
assert
isinstance
(
current_tool_call
,
dict
)
content
=
content
,
if
self
.
current_tool_id
==
len
(
self
.
prev_tool_call_arr
)
-
1
:
tool_calls
=
tool_call_deltas
,
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
current_tool_call
)
else
:
self
.
prev_tool_call_arr
.
append
(
current_tool_call
)
return
delta
return
None
except
Exception
:
except
Exception
:
logger
.
exception
(
"Error trying to handle streaming tool call."
)
logger
.
exception
(
"Error trying to handle streaming tool call."
)
return
None
# do not stream a delta. skip this token ID.
return
None
vllm/tool_parsers/longcat_tool_parser.py
View file @
aee4c146
...
@@ -16,23 +16,7 @@ class LongcatFlashToolParser(Hermes2ProToolParser):
...
@@ -16,23 +16,7 @@ class LongcatFlashToolParser(Hermes2ProToolParser):
self
.
tool_call_end_token
:
str
=
"</longcat_tool_call>"
self
.
tool_call_end_token
:
str
=
"</longcat_tool_call>"
self
.
tool_call_regex
=
re
.
compile
(
self
.
tool_call_regex
=
re
.
compile
(
r
"<longcat_tool_call>(.*?)</longcat_tool_call>|<longcat_tool_call>(.*)"
,
r
"<longcat_tool_call>(.*?)</longcat_tool_call>"
r
"|<longcat_tool_call>(.*)"
,
re
.
DOTALL
,
re
.
DOTALL
,
)
)
self
.
tool_call_start_token_ids
=
self
.
model_tokenizer
.
encode
(
self
.
tool_call_start_token
,
add_special_tokens
=
False
)
self
.
tool_call_end_token_ids
=
self
.
model_tokenizer
.
encode
(
self
.
tool_call_end_token
,
add_special_tokens
=
False
)
self
.
tool_call_start_token_array
=
[
self
.
model_tokenizer
.
decode
([
token_id
])
for
token_id
in
self
.
tool_call_start_token_ids
]
self
.
tool_call_end_token_array
=
[
self
.
model_tokenizer
.
decode
([
token_id
])
for
token_id
in
self
.
tool_call_end_token_ids
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment