Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0e39202c
Unverified
Commit
0e39202c
authored
Apr 13, 2026
by
Flora Feng
Committed by
GitHub
Apr 13, 2026
Browse files
[Bugfix] Fix GLM tool parser streaming with MTP or stream interval (#39253)
Signed-off-by:
sfeng33
<
4florafeng@gmail.com
>
parent
9dd5ee01
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
788 additions
and
416 deletions
+788
-416
tests/tool_parsers/test_glm47_moe_tool_parser.py
tests/tool_parsers/test_glm47_moe_tool_parser.py
+14
-17
tests/tool_parsers/test_glm4_moe_tool_parser.py
tests/tool_parsers/test_glm4_moe_tool_parser.py
+509
-103
vllm/tool_parsers/glm4_moe_tool_parser.py
vllm/tool_parsers/glm4_moe_tool_parser.py
+252
-296
vllm/tool_parsers/utils.py
vllm/tool_parsers/utils.py
+13
-0
No files found.
tests/tool_parsers/test_glm47_moe_tool_parser.py
View file @
0e39202c
...
...
@@ -117,28 +117,24 @@ class TestGlm47ExtractToolCalls:
def
_reset
(
parser
):
parser
.
_buffer
=
""
parser
.
_in_tool_call
=
False
parser
.
current_tool_name_sent
=
False
parser
.
_current_tool_name
=
None
parser
.
_pending_key
=
None
parser
.
_streaming_string_value
=
False
parser
.
prev_tool_call_arr
=
[]
parser
.
current_tool_id
=
-
1
parser
.
streamed_args_for_tool
=
[]
parser
.
_tool_call_ids
=
[]
parser
.
_args_started
=
[]
parser
.
_args_closed
=
[]
parser
.
_seen_keys
=
[]
parser
.
_sent_content_idx
=
0
class
TestGlm47Streaming
:
def
test_no_args
(
self
,
glm47_tool_parser
,
mock_request
):
_reset
(
glm47_tool_parser
)
for
chunk
in
[
"<tool_call>"
,
"get_current_date"
,
"</tool_call>"
]:
chunks
=
[
"<tool_call>"
,
"get_current_date"
,
"</tool_call>"
]
current_text
=
""
for
chunk
in
chunks
:
current_text
+=
chunk
glm47_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
""
,
current_text
=
current_text
,
delta_text
=
chunk
,
previous_token_ids
=
[],
current_token_ids
=
[],
...
...
@@ -149,10 +145,7 @@ class TestGlm47Streaming:
def
test_with_args
(
self
,
glm47_tool_parser
,
mock_request
):
_reset
(
glm47_tool_parser
)
# Split chunks so that the incremental string streaming path
# processes the value, its closing tag, and the tool-call closing
# tag in separate calls.
for
chunk
in
[
chunks
=
[
"<tool_call>"
,
"get_weather
\n
"
,
"<arg_key>city</arg_key>"
,
...
...
@@ -160,14 +153,18 @@ class TestGlm47Streaming:
"Beijing"
,
"</arg_value>"
,
"</tool_call>"
,
]:
]
current_text
=
""
for
chunk
in
chunks
:
current_text
+=
chunk
glm47_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
""
,
current_text
=
current_text
,
delta_text
=
chunk
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
assert
glm47_tool_parser
.
prev_tool_call_arr
[
0
][
"arguments"
][
"city"
]
==
"Beijing"
args
=
json
.
loads
(
glm47_tool_parser
.
prev_tool_call_arr
[
0
][
"arguments"
])
assert
args
[
"city"
]
==
"Beijing"
tests/tool_parsers/test_glm4_moe_tool_parser.py
View file @
0e39202c
...
...
@@ -357,81 +357,69 @@ meaningwhile, I will also check the weather in Shanghai.
def
test_streaming_basic_functionality
(
glm4_moe_tool_parser
,
mock_request
):
"""Test basic streaming functionality."""
# Reset streaming state
glm4_moe_tool_parser
.
current_tool_name_sent
=
False
glm4_moe_tool_parser
.
prev_tool_call_arr
=
[]
glm4_moe_tool_parser
.
current_tool_id
=
-
1
glm4_moe_tool_parser
.
streamed_args_for_tool
=
[]
_reset_streaming_state
(
glm4_moe_tool_parser
)
# Test with a simple tool call
current_text
=
"""<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
</tool_call>"""
# Mock token IDs for testing
tool_call_start_id
=
glm4_moe_tool_parser
.
tool_call_start_token_id
or
12345
tool_call_end_id
=
glm4_moe_tool_parser
.
tool_call_end_token_id
or
12346
result
=
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
current_text
,
delta_text
=
"</tool_call>"
,
delta_text
=
current_text
,
previous_token_ids
=
[],
current_token_ids
=
[
tool_call_start_id
,
tool_call_end_id
],
delta_token_ids
=
[
tool_call_end_id
],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
# The result behavior depends on the streaming state
# This test mainly ensures no exceptions are thrown
assert
result
is
None
or
hasattr
(
result
,
"tool_calls"
)
or
hasattr
(
result
,
"content"
)
# Should return tool call with name and arguments in one shot
assert
result
is
not
None
assert
result
.
tool_calls
is
not
None
assert
len
(
result
.
tool_calls
)
>=
1
def
test_streaming_no_tool_calls
(
glm4_moe_tool_parser
,
mock_request
):
"""Test streaming when there are no tool calls."""
_reset_streaming_state
(
glm4_moe_tool_parser
)
current_text
=
"This is just regular text without any tool calls."
result
=
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"
This is just regular text
"
,
previous_text
=
""
,
current_text
=
current_text
,
delta_text
=
" without any tool calls."
,
delta_text
=
current_text
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
# Should return
the delta text as
content
# Should return content
assert
result
is
not
None
assert
hasattr
(
result
,
"content"
)
assert
result
.
content
==
" without any tool calls."
assert
result
.
content
==
current_text
def
test_streaming_with_content_before_tool_calls
(
glm4_moe_tool_parser
,
mock_request
):
"""Test streaming when there's content before tool calls."""
# Reset streaming state
glm4_moe_tool_parser
.
current_tool_name_sent
=
False
glm4_moe_tool_parser
.
prev_tool_call_arr
=
[]
glm4_moe_tool_parser
.
current_tool_id
=
-
1
glm4_moe_tool_parser
.
streamed_args_for_tool
=
[]
_reset_streaming_state
(
glm4_moe_tool_parser
)
current_text
=
"I will help you get the weather<tool_call>"
current_text
=
"I will help you get the weather
.
<tool_call>"
result
=
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"
I will help you
"
,
previous_text
=
""
,
current_text
=
current_text
,
delta_text
=
"get the weather.<tool_call>"
,
delta_text
=
current_text
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
# Should return content
when no
tool
call t
okens are detected
# Should return content
before the <
tool
_
call
>
t
ag
assert
result
is
not
None
assert
hasattr
(
result
,
"content"
)
assert
result
.
content
==
"get the weather."
assert
result
.
content
==
"I will help you get the weather."
def
test_extract_tool_calls_special_characters
(
glm4_moe_tool_parser
,
mock_request
):
...
...
@@ -479,26 +467,19 @@ def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser, mock_requ
def
_reset_streaming_state
(
parser
):
"""Helper to reset parser streaming state."""
parser
.
_buffer
=
""
parser
.
_in_tool_call
=
False
parser
.
current_tool_name_sent
=
False
parser
.
_current_tool_name
=
None
parser
.
_pending_key
=
None
parser
.
_streaming_string_value
=
False
parser
.
prev_tool_call_arr
=
[]
parser
.
current_tool_id
=
-
1
parser
.
streamed_args_for_tool
=
[]
parser
.
_tool_call_ids
=
[]
parser
.
_args_started
=
[]
parser
.
_args_closed
=
[]
parser
.
_seen_keys
=
[]
parser
.
_sent_content_idx
=
0
def
test_streaming_incremental_string_value
(
glm4_moe_tool_parser
,
mock_request
):
"""Test incremental streaming of string argument values."""
_reset_streaming_state
(
glm4_moe_tool_parser
)
# Simulate streaming a tool call ch
aracter by character
# Simulate streaming a tool call ch
unk by chunk
chunks
=
[
"<tool_call>"
,
"get_weather
\n
"
,
...
...
@@ -511,19 +492,20 @@ def test_streaming_incremental_string_value(glm4_moe_tool_parser, mock_request):
]
collected_fragments
=
[]
current_text
=
""
for
chunk
in
chunks
:
current_text
+=
chunk
result
=
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
""
,
current_text
=
current_text
,
delta_text
=
chunk
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
if
result
is
not
None
and
hasattr
(
result
,
"tool_calls"
)
and
result
.
tool_calls
:
if
result
is
not
None
and
result
.
tool_calls
:
for
tc
in
result
.
tool_calls
:
if
hasattr
(
tc
,
"function"
)
and
tc
.
function
:
func
=
tc
.
function
if
isinstance
(
func
,
dict
):
if
func
.
get
(
"arguments"
):
...
...
@@ -547,11 +529,11 @@ def test_streaming_empty_tool_call(glm4_moe_tool_parser, mock_request):
"""Test that empty tool calls don't cause infinite loops."""
_reset_streaming_state
(
glm4_moe_tool_parser
)
# Empty
tool
call
should be handled gracefully
current_text
=
"<
tool
_
call
></tool_call>"
result
=
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
""
,
delta_text
=
"<tool_call></tool_call>"
,
current_text
=
current_text
,
delta_text
=
current_text
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
...
...
@@ -561,60 +543,52 @@ def test_streaming_empty_tool_call(glm4_moe_tool_parser, mock_request):
# Should not hang and should return something (None or content)
# The key is that this completes without hanging
assert
result
is
None
or
hasattr
(
result
,
"content"
)
or
hasattr
(
result
,
"tool_calls"
)
# State should be properly reset
assert
glm4_moe_tool_parser
.
current_tool_id
==
-
1
def
test_streaming_prev_tool_call_arr_updates
(
glm4_moe_tool_parser
,
mock_request
):
"""Test that prev_tool_call_arr
contains parsed dict after tool c
all."""
"""Test that prev_tool_call_arr
is populated increment
all
y
."""
_reset_streaming_state
(
glm4_moe_tool_parser
)
# Stream a complete tool call
name_only
=
{
"name"
:
"get_weather"
,
"arguments"
:
{}}
name_and_args
=
{
"name"
:
"get_weather"
,
"arguments"
:
{
"city"
:
"Beijing"
}}
chunks
=
[
# Delta, expected streamed_args_for_tool, expected prev_tool_call_arr
(
"<tool_call>get_weather
\n
"
,
""
,
name_only
),
(
"<arg_key>city</arg_key>"
,
""
,
name_only
),
(
"<arg_value>Beijing</arg_value>"
,
'{"city": "Beijing"'
,
name_only
),
# Note: arguments are only updated when the tool call is complete.
(
"</tool_call>"
,
'{"city": "Beijing"}'
,
name_and_args
),
"<tool_call>get_weather
\n
"
,
"<arg_key>city</arg_key>"
,
"<arg_value>Beijing</arg_value>"
,
"</tool_call>"
,
]
for
chunk
,
exp_streamed
,
exp_prev_tc
in
chunks
:
current_text
=
""
for
chunk
in
chunks
:
current_text
+=
chunk
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
""
,
current_text
=
current_text
,
delta_text
=
chunk
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
assert
glm4_moe_tool_parser
.
streamed_args_for_tool
[
0
]
==
exp_streamed
assert
glm4_moe_tool_parser
.
prev_tool_call_arr
[
0
]
==
exp_prev_tc
# After the tool call completes, prev_tool_call_arr should
have parsed dict
# After the tool call completes, prev_tool_call_arr should
be populated
assert
len
(
glm4_moe_tool_parser
.
prev_tool_call_arr
)
==
1
tool_entry
=
glm4_moe_tool_parser
.
prev_tool_call_arr
[
0
]
assert
tool_entry
.
get
(
"name"
)
==
"get_weather"
# arguments should be a dict, not a string
args
=
tool_entry
.
get
(
"arguments"
)
assert
isinstance
(
args
,
dict
),
f
"Expected dict, got
{
type
(
args
)
}
"
assert
args
.
get
(
"city"
)
==
"Beijing"
# Test equivalence of prev_tool_call_arr and streamed_args_for_tool
# Simulates logic in chat_completion/serving.py:chat_completion_stream_generator
tool_call_json
=
json
.
dumps
(
tool_entry
.
get
(
"arguments"
,
{}))
streamed_content
=
glm4_moe_tool_parser
.
streamed_args_for_tool
[
0
]
assert
tool_call_json
.
startswith
(
streamed_content
)
# arguments is a JSON string in the re-parse approach
args_str
=
tool_entry
.
get
(
"arguments"
)
assert
isinstance
(
args_str
,
str
),
f
"Expected str, got
{
type
(
args_str
)
}
"
parsed
=
json
.
loads
(
args_str
)
assert
parsed
[
"city"
]
==
"Beijing"
# streamed_args_for_tool should match prev_tool_call_arr arguments
streamed
=
glm4_moe_tool_parser
.
streamed_args_for_tool
[
0
]
assert
streamed
==
args_str
def
test_streaming_multiple_tool_calls_sequential
(
glm4_moe_tool_parser
,
mock_request
):
"""Test streaming multiple sequential tool calls."""
_reset_streaming_state
(
glm4_moe_tool_parser
)
# Stream two tool calls
chunks
=
[
"<tool_call>get_weather
\n
"
,
"<arg_key>city</arg_key>"
,
...
...
@@ -626,10 +600,12 @@ def test_streaming_multiple_tool_calls_sequential(glm4_moe_tool_parser, mock_req
"</tool_call>"
,
]
current_text
=
""
for
chunk
in
chunks
:
current_text
+=
chunk
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
""
,
current_text
=
current_text
,
delta_text
=
chunk
,
previous_token_ids
=
[],
current_token_ids
=
[],
...
...
@@ -639,15 +615,16 @@ def test_streaming_multiple_tool_calls_sequential(glm4_moe_tool_parser, mock_req
# Should have two tool calls in prev_tool_call_arr
assert
len
(
glm4_moe_tool_parser
.
prev_tool_call_arr
)
==
2
assert
glm4_moe_tool_parser
.
prev_tool_call_arr
[
0
][
"arguments"
][
"city"
]
==
"Beijing"
assert
glm4_moe_tool_parser
.
prev_tool_call_arr
[
1
][
"arguments"
][
"city"
]
==
"Shanghai"
args0
=
json
.
loads
(
glm4_moe_tool_parser
.
prev_tool_call_arr
[
0
][
"arguments"
])
args1
=
json
.
loads
(
glm4_moe_tool_parser
.
prev_tool_call_arr
[
1
][
"arguments"
])
assert
args0
[
"city"
]
==
"Beijing"
assert
args1
[
"city"
]
==
"Shanghai"
def
test_streaming_json_escape_in_string
(
glm4_moe_tool_parser
,
mock_request
):
"""Test that special characters in string values are properly escaped."""
_reset_streaming_state
(
glm4_moe_tool_parser
)
# String with characters that need JSON escaping
chunks
=
[
"<tool_call>send_message
\n
"
,
"<arg_key>message</arg_key>"
,
...
...
@@ -655,10 +632,12 @@ def test_streaming_json_escape_in_string(glm4_moe_tool_parser, mock_request):
"</tool_call>"
,
]
current_text
=
""
for
chunk
in
chunks
:
current_text
+=
chunk
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
""
,
current_text
=
current_text
,
delta_text
=
chunk
,
previous_token_ids
=
[],
current_token_ids
=
[],
...
...
@@ -669,10 +648,8 @@ def test_streaming_json_escape_in_string(glm4_moe_tool_parser, mock_request):
# The streamed_args_for_tool should contain valid JSON
assert
len
(
glm4_moe_tool_parser
.
streamed_args_for_tool
)
==
1
args_json
=
glm4_moe_tool_parser
.
streamed_args_for_tool
[
0
]
# Should be parseable as JSON
parsed
=
json
.
loads
(
args_json
)
assert
"message"
in
parsed
# The value should preserve the special characters
assert
'"'
in
parsed
[
"message"
]
or
"world"
in
parsed
[
"message"
]
...
...
@@ -749,25 +726,25 @@ if __name__ == "__main__":
# Count argument fragments
fragment_count
=
0
current_text
=
""
for
chunk
in
chunks
:
current_text
+=
chunk
result
=
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
""
,
current_text
=
current_text
,
delta_text
=
chunk
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
request
,
)
if
result
is
not
None
and
hasattr
(
result
,
"tool_calls"
)
and
result
.
tool_calls
:
if
result
is
not
None
and
result
.
tool_calls
:
for
tc
in
result
.
tool_calls
:
if
hasattr
(
tc
,
"function"
)
and
tc
.
function
:
func
=
tc
.
function
args
=
(
func
.
get
(
"arguments"
)
if
isinstance
(
func
,
dict
)
else
getattr
(
func
,
"arguments"
,
None
)
)
if
isinstance
(
func
,
dict
):
args
=
func
.
get
(
"arguments"
)
else
:
args
=
getattr
(
func
,
"arguments"
,
None
)
if
args
:
fragment_count
+=
1
...
...
@@ -927,3 +904,432 @@ def test_unicode_characters_preserved(glm4_moe_tool_parser, mock_request):
parsed_args
=
json
.
loads
(
raw_args
)
assert
parsed_args
[
"greeting"
]
==
"你好世界"
assert
parsed_args
[
"emoji"
]
==
"🎉"
def
test_streaming_multi_token_chunks
(
glm4_moe_tool_parser
,
mock_request
):
"""Test that multi-token chunks (stream_interval > 1) are handled correctly.
With stream_interval > 1 or MTP, multiple XML tags arrive in one delta.
The old buffer-based parser could only return one delta per call, losing
data on the final output. The re-parse approach handles this correctly.
"""
_reset_streaming_state
(
glm4_moe_tool_parser
)
# Simulate stream_interval=3: chunks contain multiple XML tags
chunks
=
[
"<tool_call>get_weather
\n
<arg_key>city</arg_key><arg_value>Bei"
,
"jing</arg_value>"
,
"</tool_call>"
,
]
current_text
=
""
for
chunk
in
chunks
:
current_text
+=
chunk
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
current_text
,
delta_text
=
chunk
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
# All data should be captured despite multi-token chunks
assert
len
(
glm4_moe_tool_parser
.
prev_tool_call_arr
)
==
1
args
=
json
.
loads
(
glm4_moe_tool_parser
.
streamed_args_for_tool
[
0
])
assert
args
[
"city"
]
==
"Beijing"
def
test_streaming_entire_tool_call_at_once
(
glm4_moe_tool_parser
,
mock_request
):
"""Test that a complete tool call arriving in one delta works.
This simulates the extreme MTP case where all tokens arrive at once.
"""
_reset_streaming_state
(
glm4_moe_tool_parser
)
full_text
=
(
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Beijing</arg_value>"
"</tool_call>"
)
result
=
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
full_text
,
delta_text
=
full_text
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
# Should emit tool call with complete arguments in one shot
assert
result
is
not
None
assert
result
.
tool_calls
is
not
None
# Verify final state
assert
len
(
glm4_moe_tool_parser
.
prev_tool_call_arr
)
==
1
args
=
json
.
loads
(
glm4_moe_tool_parser
.
streamed_args_for_tool
[
0
])
assert
args
[
"city"
]
==
"Beijing"
def
test_streaming_content_between_tool_calls_multi_token
(
glm4_moe_tool_parser
,
mock_request
):
"""Test content between tool calls with multi-token chunks."""
_reset_streaming_state
(
glm4_moe_tool_parser
)
# Deliver everything at once — worst case for the old buffer parser
full_text
=
(
"I will check.
\n
"
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Beijing</arg_value>"
"</tool_call>"
"
\n
Also Shanghai.
\n
"
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Shanghai</arg_value>"
"</tool_call>"
)
# First call with partial text (content only)
partial
=
"I will check.
\n
"
result1
=
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
partial
,
delta_text
=
partial
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
assert
result1
is
not
None
assert
result1
.
content
==
"I will check.
\n
"
# Second call with everything
glm4_moe_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
full_text
,
delta_text
=
full_text
[
len
(
partial
)
:],
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
mock_request
,
)
# Should have both tool calls
assert
len
(
glm4_moe_tool_parser
.
prev_tool_call_arr
)
==
2
args0
=
json
.
loads
(
glm4_moe_tool_parser
.
prev_tool_call_arr
[
0
][
"arguments"
])
args1
=
json
.
loads
(
glm4_moe_tool_parser
.
prev_tool_call_arr
[
1
][
"arguments"
])
assert
args0
[
"city"
]
==
"Beijing"
assert
args1
[
"city"
]
==
"Shanghai"
def
test_streaming_multi_token_with_multiple_args
(
glm4_moe_tokenizer
):
"""Test multi-token streaming with multiple arguments of mixed types."""
tools
=
[
ChatCompletionToolsParam
(
function
=
FunctionDefinition
(
name
=
"calculate"
,
parameters
=
{
"type"
:
"object"
,
"properties"
:
{
"operation"
:
{
"type"
:
"string"
},
"a"
:
{
"type"
:
"number"
},
"b"
:
{
"type"
:
"number"
},
},
},
),
),
]
parser
=
Glm4MoeModelToolParser
(
glm4_moe_tokenizer
,
tools
=
tools
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
tools
,
)
# All arguments arrive in two big chunks (simulates stream_interval=5)
chunks
=
[
"<tool_call>calculate
\n
<arg_key>operation</arg_key><arg_value>add</arg_value><arg_key>a</arg_key>"
,
"<arg_value>42</arg_value><arg_key>b</arg_key><arg_value>3.14</arg_value></tool_call>"
,
]
current_text
=
""
for
chunk
in
chunks
:
current_text
+=
chunk
parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
current_text
,
delta_text
=
chunk
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
[],
request
=
request
,
)
args
=
json
.
loads
(
parser
.
streamed_args_for_tool
[
0
])
assert
args
[
"operation"
]
==
"add"
assert
args
[
"a"
]
==
42
assert
args
[
"b"
]
==
3.14
def
_simulate_streaming
(
tokenizer
,
parser
,
request
,
text
,
stream_interval
=
1
):
"""Simulate streaming with a given stream_interval.
Tokens are batched into chunks of ``stream_interval`` tokens,
mimicking how the output processor delivers them.
Returns a list of non-None DeltaMessages.
"""
tokens
=
tokenizer
.
encode
(
text
)
previous_text
=
""
deltas
=
[]
for
i
in
range
(
0
,
len
(
tokens
),
stream_interval
):
chunk_ids
=
tokens
[
i
:
i
+
stream_interval
]
delta_text
=
tokenizer
.
decode
(
chunk_ids
)
current_text
=
previous_text
+
delta_text
delta
=
parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
,
current_text
=
current_text
,
delta_text
=
delta_text
,
previous_token_ids
=
[],
current_token_ids
=
[],
delta_token_ids
=
chunk_ids
,
request
=
request
,
)
previous_text
=
current_text
if
delta
is
not
None
:
deltas
.
append
(
delta
)
return
deltas
def
_collect_from_deltas
(
deltas
):
"""Reconstruct tool call names/args and content from a delta stream."""
tools
:
dict
[
int
,
dict
]
=
{}
content_parts
:
list
[
str
]
=
[]
for
d
in
deltas
:
if
d
.
content
:
content_parts
.
append
(
d
.
content
)
if
d
.
tool_calls
:
for
tc
in
d
.
tool_calls
:
func
=
tc
.
function
if
isinstance
(
func
,
dict
):
name
=
func
.
get
(
"name"
)
args
=
func
.
get
(
"arguments"
)
else
:
name
=
getattr
(
func
,
"name"
,
None
)
args
=
getattr
(
func
,
"arguments"
,
None
)
idx
=
tc
.
index
if
idx
not
in
tools
:
tools
[
idx
]
=
{
"name"
:
None
,
"args_fragments"
:
[]}
if
name
:
tools
[
idx
][
"name"
]
=
name
if
args
:
tools
[
idx
][
"args_fragments"
].
append
(
args
)
return
content_parts
,
tools
@
pytest
.
mark
.
parametrize
(
"stream_interval"
,
[
1
,
2
,
3
,
5
,
8
])
def
test_stream_interval_single_tool_call
(
glm4_moe_tokenizer
,
stream_interval
):
"""Tool call streaming produces correct name + args at any interval."""
tools
=
[
ChatCompletionToolsParam
(
function
=
FunctionDefinition
(
name
=
"get_weather"
,
parameters
=
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
}},
},
),
),
]
parser
=
Glm4MoeModelToolParser
(
glm4_moe_tokenizer
,
tools
=
tools
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
tools
)
text
=
(
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Beijing</arg_value>"
"</tool_call>"
)
deltas
=
_simulate_streaming
(
glm4_moe_tokenizer
,
parser
,
request
,
text
,
stream_interval
)
_
,
tools_found
=
_collect_from_deltas
(
deltas
)
assert
0
in
tools_found
assert
tools_found
[
0
][
"name"
]
==
"get_weather"
args_json
=
""
.
join
(
tools_found
[
0
][
"args_fragments"
])
parsed
=
json
.
loads
(
args_json
)
assert
parsed
==
{
"city"
:
"Beijing"
}
@
pytest
.
mark
.
parametrize
(
"stream_interval"
,
[
1
,
2
,
3
,
5
,
8
])
def
test_stream_interval_multiple_tool_calls
(
glm4_moe_tokenizer
,
stream_interval
):
"""Multiple sequential tool calls with correct indices at any interval."""
tools
=
[
ChatCompletionToolsParam
(
function
=
FunctionDefinition
(
name
=
"get_weather"
,
parameters
=
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
}},
},
),
),
]
parser
=
Glm4MoeModelToolParser
(
glm4_moe_tokenizer
,
tools
=
tools
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
tools
)
text
=
(
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Beijing</arg_value>"
"</tool_call>"
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Shanghai</arg_value>"
"</tool_call>"
)
deltas
=
_simulate_streaming
(
glm4_moe_tokenizer
,
parser
,
request
,
text
,
stream_interval
)
_
,
tools_found
=
_collect_from_deltas
(
deltas
)
assert
0
in
tools_found
and
1
in
tools_found
args0
=
json
.
loads
(
""
.
join
(
tools_found
[
0
][
"args_fragments"
]))
args1
=
json
.
loads
(
""
.
join
(
tools_found
[
1
][
"args_fragments"
]))
assert
args0
==
{
"city"
:
"Beijing"
}
assert
args1
==
{
"city"
:
"Shanghai"
}
@
pytest
.
mark
.
parametrize
(
"stream_interval"
,
[
1
,
2
,
3
,
5
,
8
])
def
test_stream_interval_content_then_tool_call
(
glm4_moe_tokenizer
,
stream_interval
):
"""Content before a tool call is fully emitted before tool deltas."""
tools
=
[
ChatCompletionToolsParam
(
function
=
FunctionDefinition
(
name
=
"get_weather"
,
parameters
=
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
}},
},
),
),
]
parser
=
Glm4MoeModelToolParser
(
glm4_moe_tokenizer
,
tools
=
tools
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
tools
)
text
=
(
"I will check the weather for you.
\n
"
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Beijing</arg_value>"
"</tool_call>"
)
deltas
=
_simulate_streaming
(
glm4_moe_tokenizer
,
parser
,
request
,
text
,
stream_interval
)
content_parts
,
tools_found
=
_collect_from_deltas
(
deltas
)
# Content must be present and precede tool calls
full_content
=
""
.
join
(
content_parts
)
assert
"I will check the weather"
in
full_content
# Tool call must be correct
assert
0
in
tools_found
assert
tools_found
[
0
][
"name"
]
==
"get_weather"
args
=
json
.
loads
(
""
.
join
(
tools_found
[
0
][
"args_fragments"
]))
assert
args
==
{
"city"
:
"Beijing"
}
def
test_stream_interval_extreme_single_chunk
(
glm4_moe_tokenizer
):
"""Extreme MTP: entire output arrives in one chunk (interval=9999)."""
tools
=
[
ChatCompletionToolsParam
(
function
=
FunctionDefinition
(
name
=
"get_weather"
,
parameters
=
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
}},
},
),
),
]
parser
=
Glm4MoeModelToolParser
(
glm4_moe_tokenizer
,
tools
=
tools
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
tools
)
text
=
(
"Here is the weather.
\n
"
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Beijing</arg_value>"
"</tool_call>"
)
deltas
=
_simulate_streaming
(
glm4_moe_tokenizer
,
parser
,
request
,
text
,
stream_interval
=
9999
)
content_parts
,
tools_found
=
_collect_from_deltas
(
deltas
)
assert
"Here is the weather"
in
""
.
join
(
content_parts
)
assert
0
in
tools_found
assert
tools_found
[
0
][
"name"
]
==
"get_weather"
args
=
json
.
loads
(
""
.
join
(
tools_found
[
0
][
"args_fragments"
]))
assert
args
==
{
"city"
:
"Beijing"
}
@
pytest
.
mark
.
parametrize
(
"stream_interval"
,
[
1
,
2
,
5
])
def
test_stream_interval_content_between_tool_calls
(
glm4_moe_tokenizer
,
stream_interval
):
"""Content between tool calls must be emitted, not silently dropped."""
tools
=
[
ChatCompletionToolsParam
(
function
=
FunctionDefinition
(
name
=
"get_weather"
,
parameters
=
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
}},
},
),
),
]
parser
=
Glm4MoeModelToolParser
(
glm4_moe_tokenizer
,
tools
=
tools
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
tools
)
text
=
(
"Checking Beijing.
\n
"
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Beijing</arg_value>"
"</tool_call>"
"
\n
Also Shanghai.
\n
"
"<tool_call>get_weather
\n
"
"<arg_key>city</arg_key>"
"<arg_value>Shanghai</arg_value>"
"</tool_call>"
)
deltas
=
_simulate_streaming
(
glm4_moe_tokenizer
,
parser
,
request
,
text
,
stream_interval
)
content_parts
,
tools_found
=
_collect_from_deltas
(
deltas
)
full_content
=
""
.
join
(
content_parts
)
# Both prefix and inter-tool-call content must appear
assert
"Checking Beijing"
in
full_content
assert
"Also Shanghai"
in
full_content
# Both tool calls must be correct
assert
0
in
tools_found
and
1
in
tools_found
args0
=
json
.
loads
(
""
.
join
(
tools_found
[
0
][
"args_fragments"
]))
args1
=
json
.
loads
(
""
.
join
(
tools_found
[
1
][
"args_fragments"
]))
assert
args0
==
{
"city"
:
"Beijing"
}
assert
args1
==
{
"city"
:
"Shanghai"
}
vllm/tool_parsers/glm4_moe_tool_parser.py
View file @
0e39202c
...
...
@@ -37,6 +37,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
Tool
,
ToolParser
,
)
from
vllm.tool_parsers.utils
import
partial_tag_overlap
logger
=
init_logger
(
__name__
)
...
...
@@ -44,9 +45,9 @@ logger = init_logger(__name__)
class
Glm4MoeModelToolParser
(
ToolParser
):
"""Tool parser for GLM-4 models with incremental string streaming.
This parser emits tool-call deltas incrementally as arguments arrive.
For string-type parameters, content is streamed character-by-character
rather than waiting for the complete </arg_value> tag
.
On every streaming call the parser re-parses ``current_text`` to find
``<tool_call>`` regions, builds the JSON arguments string for each tool
call, and diffs against what was previously sent to emit only new content
.
"""
def
__init__
(
self
,
tokenizer
:
TokenizerLike
,
tools
:
list
[
Tool
]
|
None
=
None
):
...
...
@@ -82,17 +83,17 @@ class Glm4MoeModelToolParser(ToolParser):
self
.
tool_call_start_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_start_token
)
self
.
tool_call_end_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_end_token
)
self
.
_buffer
:
str
=
""
# Streaming state for incremental tool-call streaming
self
.
_in_tool_call
:
bool
=
False
self
.
_current_tool_name
:
str
|
None
=
None
self
.
_pending_key
:
str
|
None
=
None
self
.
_streaming_string_value
:
bool
=
False
# Pre-compiled pattern for finding the last <arg_key>...</arg_key>
# before a partial <arg_value> (used in _build_args_json_so_far).
self
.
_arg_key_pattern
=
re
.
compile
(
re
.
escape
(
self
.
arg_key_start
)
+
r
"(.*?)"
+
re
.
escape
(
self
.
arg_key_end
),
re
.
DOTALL
,
)
# Streaming state for re-parse-and-diff approach
self
.
_sent_content_idx
:
int
=
0
self
.
_tool_call_ids
:
list
[
str
]
=
[]
self
.
_args_started
:
list
[
bool
]
=
[]
self
.
_args_closed
:
list
[
bool
]
=
[]
self
.
_seen_keys
:
list
[
set
[
str
]]
=
[]
@
staticmethod
def
_deserialize
(
value
:
str
)
->
Any
:
...
...
@@ -222,306 +223,261 @@ class Glm4MoeModelToolParser(ToolParser):
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
if
not
self
.
_tools_enabled
(
request
):
return
DeltaMessage
(
content
=
delta_text
)
if
delta_text
else
None
self
.
_buffer
+=
delta_text
def
_extract_content
(
self
,
current_text
:
str
)
->
str
|
None
:
"""Return unsent non-tool-call text, or None.
while
True
:
if
not
self
.
_in_tool_call
:
start_idx
=
self
.
_buffer
.
find
(
self
.
tool_call_start_token
)
if
start_idx
==
-
1
:
# Check for partial start token at end of buffer
for
i
in
range
(
1
,
len
(
self
.
tool_call_start_token
)):
if
self
.
_buffer
.
endswith
(
self
.
tool_call_start_token
[:
i
]):
out
=
self
.
_buffer
[:
-
i
]
self
.
_buffer
=
self
.
_buffer
[
-
i
:]
return
DeltaMessage
(
content
=
out
)
if
out
else
None
out
=
self
.
_buffer
self
.
_buffer
=
""
return
DeltaMessage
(
content
=
out
)
if
out
else
None
if
start_idx
>
0
:
out
=
self
.
_buffer
[:
start_idx
]
self
.
_buffer
=
self
.
_buffer
[
start_idx
:]
return
DeltaMessage
(
content
=
out
)
if
out
else
None
self
.
_buffer
=
self
.
_buffer
[
len
(
self
.
tool_call_start_token
)
:]
self
.
_begin_tool_call
()
continue
Collects all text outside ``<tool_call>...</tool_call>`` regions,
including text between consecutive tool calls. Holds back any
suffix that could be a partial ``<tool_call>`` tag.
"""
# Build the "sendable index" — the furthest point we can send
# content up to. We scan through the text collecting segments
# that are outside tool-call regions.
content_segments
:
list
[
str
]
=
[]
pos
=
self
.
_sent_content_idx
while
pos
<
len
(
current_text
):
start
=
current_text
.
find
(
self
.
tool_call_start_token
,
pos
)
if
start
==
-
1
:
# No more tool calls — send up to (len - partial-tag overlap)
tail
=
current_text
[
pos
:]
overlap
=
partial_tag_overlap
(
tail
,
self
.
tool_call_start_token
)
sendable
=
tail
[:
len
(
tail
)
-
overlap
]
if
overlap
else
tail
if
sendable
:
content_segments
.
append
(
sendable
)
pos
=
len
(
current_text
)
-
overlap
break
# Parse tool name first
if
not
self
.
current_tool_name_sent
:
nl
=
self
.
_buffer
.
find
(
"
\n
"
)
ak
=
self
.
_buffer
.
find
(
self
.
arg_key_start
)
end
=
self
.
_buffer
.
find
(
self
.
tool_call_end_token
)
candidates
=
[
i
for
i
in
[
nl
,
ak
,
end
]
if
i
!=
-
1
]
if
not
candidates
:
return
None
cut
=
min
(
candidates
)
tool_name
=
self
.
_buffer
[:
cut
].
strip
()
if
tool_name
==
""
and
cut
==
end
:
# Handle empty tool call like `<tool_call></tool_call>`.
# Consume the tokens and reset state to avoid infinite loop.
self
.
_buffer
=
self
.
_buffer
[
end
+
len
(
self
.
tool_call_end_token
)
:]
self
.
_finish_tool_call
()
self
.
_revert_last_tool_call_state
()
continue
# Text before this <tool_call>
if
start
>
pos
:
content_segments
.
append
(
current_text
[
pos
:
start
])
if
cut
==
nl
:
self
.
_buffer
=
self
.
_buffer
[
nl
+
1
:]
else
:
self
.
_buffer
=
self
.
_buffer
[
cut
:]
self
.
_current_tool_name
=
tool_name
self
.
current_tool_name_sent
=
True
return
self
.
_emit_tool_name_delta
(
tool_name
)
assert
self
.
_current_tool_name
is
not
None
# Handle incremental string value streaming
if
self
.
_streaming_string_value
:
val_end
=
self
.
_buffer
.
find
(
self
.
arg_val_end
)
if
val_end
!=
-
1
:
raw_content
=
self
.
_buffer
[:
val_end
]
self
.
_buffer
=
self
.
_buffer
[
val_end
+
len
(
self
.
arg_val_end
)
:]
self
.
_streaming_string_value
=
False
self
.
_pending_key
=
None
escaped
=
self
.
_json_escape_string_content
(
raw_content
)
frag
=
escaped
+
'"'
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
frag
return
self
.
_emit_tool_args_delta
(
frag
)
# Skip past the </tool_call> (or to end if incomplete)
end
=
current_text
.
find
(
self
.
tool_call_end_token
,
start
)
if
end
!=
-
1
:
pos
=
end
+
len
(
self
.
tool_call_end_token
)
else
:
# Check for partial </arg_value> at end
safe_len
=
len
(
self
.
_buffer
)
for
i
in
range
(
1
,
len
(
self
.
arg_val_end
)):
if
self
.
_buffer
.
endswith
(
self
.
arg_val_end
[:
i
]):
safe_len
=
len
(
self
.
_buffer
)
-
i
# Incomplete tool call — nothing more to send
pos
=
start
break
if
safe_len
>
0
:
to_emit
=
self
.
_buffer
[:
safe_len
]
self
.
_buffer
=
self
.
_buffer
[
safe_len
:]
escaped
=
self
.
_json_escape_string_content
(
to_emit
)
if
escaped
:
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
escaped
return
self
.
_emit_tool_args_delta
(
escaped
)
if
content_segments
:
self
.
_sent_content_idx
=
pos
return
""
.
join
(
content_segments
)
# Even if no content, advance past completed tool-call regions
if
pos
>
self
.
_sent_content_idx
:
self
.
_sent_content_idx
=
pos
return
None
# If we have a pending key, parse its value
if
self
.
_pending_key
is
not
None
:
val_pos
=
self
.
_buffer
.
find
(
self
.
arg_val_start
)
if
val_pos
==
-
1
:
return
None
if
val_pos
>
0
:
self
.
_buffer
=
self
.
_buffer
[
val_pos
:]
key
=
(
self
.
_pending_key
or
""
).
strip
()
def
_extract_tool_call_regions
(
self
,
text
:
str
)
->
list
[
tuple
[
str
,
bool
]]:
"""Extract ``(inner_text, is_complete)`` for each ``<tool_call>`` region."""
results
:
list
[
tuple
[
str
,
bool
]]
=
[]
pos
=
0
while
True
:
start
=
text
.
find
(
self
.
tool_call_start_token
,
pos
)
if
start
==
-
1
:
break
inner_start
=
start
+
len
(
self
.
tool_call_start_token
)
end
=
text
.
find
(
self
.
tool_call_end_token
,
inner_start
)
if
end
!=
-
1
:
results
.
append
((
text
[
inner_start
:
end
],
True
))
pos
=
end
+
len
(
self
.
tool_call_end_token
)
else
:
# Incomplete tool call — strip partial </tool_call> suffix
raw
=
text
[
inner_start
:]
overlap
=
partial_tag_overlap
(
raw
,
self
.
tool_call_end_token
)
if
overlap
:
raw
=
raw
[:
-
overlap
]
results
.
append
((
raw
,
False
))
break
return
results
is_string
=
self
.
_is_string_type
(
self
.
_current_tool_name
,
key
,
self
.
tools
)
def
_extract_tool_name_from_region
(
self
,
inner_text
:
str
)
->
str
|
None
:
"""Extract the tool name from the beginning of a tool-call region.
if
is_string
:
# String type: stream incrementally
self
.
_buffer
=
self
.
_buffer
[
len
(
self
.
arg_val_start
)
:]
The name is everything before the first ``
\\
n`` or ``<arg_key>``.
Returns ``None`` if the name hasn't fully arrived yet.
"""
nl
=
inner_text
.
find
(
"
\n
"
)
ak
=
inner_text
.
find
(
self
.
arg_key_start
)
candidates
=
[
i
for
i
in
[
nl
,
ak
]
if
i
!=
-
1
]
if
not
candidates
:
return
None
cut
=
min
(
candidates
)
name
=
inner_text
[:
cut
].
strip
()
return
name
if
name
else
None
if
key
in
self
.
_seen_keys
[
self
.
current_tool_id
]:
self
.
_pending_key
=
None
continue
def
_build_args_json_so_far
(
self
,
tool_name
:
str
,
inner_text
:
str
,
is_complete
:
bool
,
)
->
str
:
"""Build the JSON arguments string from the XML pairs seen so far.
For complete ``<arg_key>/<arg_value>`` pairs the value is fully
formatted. For the last argument whose ``<arg_value>`` has been
opened but not closed, the partial string content is included
(JSON-escaped, with an opening ``"`` but no closing ``"``).
The closing ``}`` is only appended when ``is_complete`` is True
(i.e. the ``</tool_call>`` tag has arrived).
"""
# Find all complete arg pairs
pairs
=
self
.
func_arg_regex
.
findall
(
inner_text
)
self
.
_seen_keys
[
self
.
current_tool_id
].
add
(
key
)
parts
:
list
[
str
]
=
[]
for
key
,
value
in
pairs
:
key
=
key
.
strip
()
key_json
=
json
.
dumps
(
key
,
ensure_ascii
=
False
)
if
not
self
.
_args_started
[
self
.
current_tool_id
]:
frag
=
"{"
+
key_json
+
': "'
self
.
_args_started
[
self
.
current_tool_id
]
=
True
if
self
.
_is_string_type
(
tool_name
,
key
,
self
.
tools
):
# Don't strip string values — whitespace is significant
# and must match the partial-value path for diffing.
val_json
=
json
.
dumps
(
value
,
ensure_ascii
=
False
)
else
:
frag
=
", "
+
key_json
+
': "'
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
frag
self
.
_streaming_string_value
=
True
return
self
.
_emit_tool_args_delta
(
frag
)
else
:
# Non-string type: wait for complete value
val_end
=
self
.
_buffer
.
find
(
self
.
arg_val_end
)
if
val_end
==
-
1
:
return
None
raw_val
=
self
.
_buffer
[
len
(
self
.
arg_val_start
)
:
val_end
].
strip
()
self
.
_buffer
=
self
.
_buffer
[
val_end
+
len
(
self
.
arg_val_end
)
:]
self
.
_pending_key
=
None
frag_or_none
=
self
.
_append_arg_fragment
(
key
=
key
,
raw_val
=
raw_val
)
if
frag_or_none
:
return
self
.
_emit_tool_args_delta
(
frag_or_none
)
continue
# Parse next arg or close
end_pos
=
self
.
_buffer
.
find
(
self
.
tool_call_end_token
)
key_pos
=
self
.
_buffer
.
find
(
self
.
arg_key_start
)
if
end_pos
!=
-
1
and
(
key_pos
==
-
1
or
end_pos
<
key_pos
):
self
.
_buffer
=
self
.
_buffer
[
end_pos
+
len
(
self
.
tool_call_end_token
)
:]
frag_or_none
=
self
.
_close_args_if_needed
()
# Finalize prev_tool_call_arr with complete parsed arguments
if
self
.
_current_tool_name
:
try
:
full_args_str
=
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
args_dict
=
json
.
loads
(
full_args_str
)
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
{
"name"
:
self
.
_current_tool_name
,
"arguments"
:
args_dict
,
}
except
(
json
.
JSONDecodeError
,
IndexError
)
as
e
:
logger
.
warning
(
"Failed to finalize tool call state for tool %d: %s"
,
self
.
current_tool_id
,
e
,
val_json
=
json
.
dumps
(
self
.
_deserialize
(
value
.
strip
()),
ensure_ascii
=
False
)
self
.
_finish_tool_call
()
return
(
self
.
_emit_tool_args_delta
(
frag_or_none
)
if
frag_or_none
else
None
parts
.
append
(
f
"
{
key_json
}
:
{
val_json
}
"
)
# Check for a partial (incomplete) arg value
# Find the last <arg_value> that isn't closed
last_val_start
=
inner_text
.
rfind
(
self
.
arg_val_start
)
last_val_end
=
inner_text
.
rfind
(
self
.
arg_val_end
)
has_partial_value
=
last_val_start
!=
-
1
and
(
last_val_end
==
-
1
or
last_val_end
<
last_val_start
)
if
key_pos
==
-
1
:
return
None
if
key_pos
>
0
:
self
.
_buffer
=
self
.
_buffer
[
key_pos
:]
key_end
=
self
.
_buffer
.
find
(
self
.
arg_key_end
)
if
key_end
==
-
1
:
if
has_partial_value
:
# Find the key for this partial value
# Look for the last <arg_key>...</arg_key> before this <arg_value>
last_key_match
=
None
for
m
in
self
.
_arg_key_pattern
.
finditer
(
inner_text
[:
last_val_start
]):
last_key_match
=
m
if
last_key_match
:
partial_key
=
last_key_match
.
group
(
1
).
strip
()
partial_content_start
=
last_val_start
+
len
(
self
.
arg_val_start
)
partial_content
=
inner_text
[
partial_content_start
:]
# Hold back any partial </arg_value> suffix
overlap
=
partial_tag_overlap
(
partial_content
,
self
.
arg_val_end
)
if
overlap
:
partial_content
=
partial_content
[:
-
overlap
]
key_json
=
json
.
dumps
(
partial_key
,
ensure_ascii
=
False
)
if
is_complete
:
# Tool call finished but </arg_value> is missing
# (malformed output). Treat partial as complete value
# so the diff naturally closes any open quotes.
if
self
.
_is_string_type
(
tool_name
,
partial_key
,
self
.
tools
):
val_json
=
json
.
dumps
(
partial_content
,
ensure_ascii
=
False
)
else
:
val_json
=
json
.
dumps
(
self
.
_deserialize
(
partial_content
.
strip
()),
ensure_ascii
=
False
,
)
parts
.
append
(
f
"
{
key_json
}
:
{
val_json
}
"
)
elif
self
.
_is_string_type
(
tool_name
,
partial_key
,
self
.
tools
):
escaped
=
self
.
_json_escape_string_content
(
partial_content
)
# Open quote but no close — more content may arrive
parts
.
append
(
f
'
{
key_json
}
: "
{
escaped
}
'
)
else
:
# Non-string partial: include raw content, no wrapping
parts
.
append
(
f
"
{
key_json
}
:
{
partial_content
}
"
)
if
not
parts
:
return
"{}"
if
is_complete
else
""
joined
=
"{"
+
", "
.
join
(
parts
)
if
is_complete
:
joined
+=
"}"
return
joined
def
_compute_args_diff
(
self
,
index
:
int
,
args_so_far
:
str
)
->
str
|
None
:
"""Return new argument text not yet sent for tool *index*, or None."""
if
not
args_so_far
or
len
(
args_so_far
)
<=
len
(
self
.
streamed_args_for_tool
[
index
]
):
return
None
key
=
self
.
_buffer
[
len
(
self
.
arg_key_start
)
:
key_end
]
self
.
_buffer
=
self
.
_buffer
[
key_end
+
len
(
self
.
arg_key_end
)
:]
self
.
_pending_key
=
key
continue
def
_ensure_tool_state
(
self
)
->
None
:
while
len
(
self
.
_tool_call_ids
)
<=
self
.
current_tool_id
:
diff
=
args_so_far
[
len
(
self
.
streamed_args_for_tool
[
index
])
:]
self
.
streamed_args_for_tool
[
index
]
=
args_so_far
self
.
prev_tool_call_arr
[
index
][
"arguments"
]
=
args_so_far
return
diff
def
_ensure_tool_state_for
(
self
,
index
:
int
)
->
None
:
"""Grow state arrays so that *index* is valid."""
while
len
(
self
.
_tool_call_ids
)
<=
index
:
self
.
_tool_call_ids
.
append
(
make_tool_call_id
(
id_type
=
"random"
,
func_name
=
None
,
idx
=
None
)
)
while
len
(
self
.
streamed_args_for_tool
)
<=
self
.
current_tool_id
:
while
len
(
self
.
streamed_args_for_tool
)
<=
index
:
self
.
streamed_args_for_tool
.
append
(
""
)
while
len
(
self
.
prev_tool_call_arr
)
<=
self
.
current_tool_id
:
while
len
(
self
.
prev_tool_call_arr
)
<=
index
:
self
.
prev_tool_call_arr
.
append
({})
while
len
(
self
.
_args_started
)
<=
self
.
current_tool_id
:
self
.
_args_started
.
append
(
False
)
while
len
(
self
.
_args_closed
)
<=
self
.
current_tool_id
:
self
.
_args_closed
.
append
(
False
)
while
len
(
self
.
_seen_keys
)
<=
self
.
current_tool_id
:
self
.
_seen_keys
.
append
(
set
())
def
_begin_tool_call
(
self
)
->
None
:
if
self
.
current_tool_id
==
-
1
:
self
.
current_tool_id
=
0
else
:
self
.
current_tool_id
+=
1
self
.
_ensure_tool_state
()
self
.
current_tool_name_sent
=
False
self
.
_current_tool_name
=
None
self
.
_pending_key
=
None
self
.
_streaming_string_value
=
False
self
.
_in_tool_call
=
True
def
_finish_tool_call
(
self
)
->
None
:
self
.
_in_tool_call
=
False
self
.
_current_tool_name
=
None
self
.
_pending_key
=
None
self
.
_streaming_string_value
=
False
def
_revert_last_tool_call_state
(
self
)
->
None
:
"""Revert the state allocation for the last tool call."""
if
self
.
current_tool_id
<
0
:
return
self
.
_tool_call_ids
.
pop
()
self
.
streamed_args_for_tool
.
pop
()
self
.
prev_tool_call_arr
.
pop
()
self
.
_args_started
.
pop
()
self
.
_args_closed
.
pop
()
self
.
_seen_keys
.
pop
()
self
.
current_tool_id
-=
1
def
_emit_tool_name_delta
(
self
,
tool_name
:
str
)
->
DeltaMessage
:
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
{
"name"
:
self
.
_current_tool_name
,
"arguments"
:
{},
}
return
DeltaMessage
(
tool_calls
=
[
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
if
not
self
.
_tools_enabled
(
request
):
return
DeltaMessage
(
content
=
delta_text
)
if
delta_text
else
None
content
=
self
.
_extract_content
(
current_text
)
regions
=
self
.
_extract_tool_call_regions
(
current_text
)
tool_call_deltas
:
list
[
DeltaToolCall
]
=
[]
for
i
,
(
inner_text
,
is_complete
)
in
enumerate
(
regions
):
self
.
_ensure_tool_state_for
(
i
)
# Extract tool name
tool_name
=
self
.
_extract_tool_name_from_region
(
inner_text
)
if
not
tool_name
:
break
# Emit tool name (once per tool call)
if
"name"
not
in
self
.
prev_tool_call_arr
[
i
]:
self
.
prev_tool_call_arr
[
i
][
"name"
]
=
tool_name
tool_call_deltas
.
append
(
DeltaToolCall
(
index
=
self
.
current_tool_id
,
id
=
self
.
_tool_call_ids
[
self
.
current_tool_id
],
index
=
i
,
id
=
self
.
_tool_call_ids
[
i
],
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
tool_name
,
arguments
=
""
,
).
model_dump
(
exclude_none
=
True
),
)
]
)
def
_emit_tool_args_delta
(
self
,
fragment
:
str
)
->
DeltaMessage
:
return
DeltaMessage
(
tool_calls
=
[
# Build args JSON so far, diff, emit
args_so_far
=
self
.
_build_args_json_so_far
(
tool_name
,
inner_text
,
is_complete
)
diff
=
self
.
_compute_args_diff
(
i
,
args_so_far
)
if
diff
:
tool_call_deltas
.
append
(
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
fragment
).
model_dump
(
index
=
i
,
function
=
DeltaFunctionCall
(
arguments
=
diff
).
model_dump
(
exclude_none
=
True
),
)
]
)
def
_append_arg_fragment
(
self
,
*
,
key
:
str
,
raw_val
:
str
,
)
->
str
|
None
:
key
=
key
.
strip
()
if
not
key
:
return
None
if
key
in
self
.
_seen_keys
[
self
.
current_tool_id
]:
return
None
# This function is only called for non-string types (already checked
# by _is_string_type in the caller), so we always deserialize.
val_obj
:
Any
=
self
.
_deserialize
(
raw_val
)
# Update current_tool_id for serving layer compatibility
if
regions
:
self
.
current_tool_id
=
len
(
regions
)
-
1
key_json
=
json
.
dumps
(
key
,
ensure_ascii
=
False
)
val_json
=
json
.
dumps
(
val_obj
,
ensure_ascii
=
False
)
if
not
self
.
_args_started
[
self
.
current_tool_id
]:
fragment
=
"{"
+
key_json
+
": "
+
val_json
self
.
_args_started
[
self
.
current_tool_id
]
=
True
else
:
fragment
=
","
+
key_json
+
": "
+
val_json
self
.
_seen_keys
[
self
.
current_tool_id
].
add
(
key
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
fragment
return
fragment
def
_close_args_if_needed
(
self
)
->
str
|
None
:
if
self
.
_args_closed
[
self
.
current_tool_id
]:
if
content
or
tool_call_deltas
:
return
DeltaMessage
(
content
=
content
,
tool_calls
=
tool_call_deltas
,
)
return
None
self
.
_args_closed
[
self
.
current_tool_id
]
=
True
if
not
self
.
_args_started
[
self
.
current_tool_id
]:
fragment
=
"{}"
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
=
fragment
else
:
fragment
=
"}"
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
fragment
return
fragment
vllm/tool_parsers/utils.py
View file @
0e39202c
...
...
@@ -31,6 +31,19 @@ Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool
logger
=
init_logger
(
__name__
)
def
partial_tag_overlap
(
text
:
str
,
tag
:
str
)
->
int
:
"""Length of the longest prefix of *tag* that matches a suffix of *text*.
E.g. text ending in ``"<tool_"`` returns 6 when tag is ``"<tool_call>"``.
Returns 0 when there is no overlap.
"""
max_check
=
min
(
len
(
tag
)
-
1
,
len
(
text
))
for
k
in
range
(
max_check
,
0
,
-
1
):
if
text
.
endswith
(
tag
[:
k
]):
return
k
return
0
def
find_common_prefix
(
s1
:
str
,
s2
:
str
)
->
str
:
"""
Finds a common prefix that is shared between two strings, if there is one.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment