Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
93269bb4
Unverified
Commit
93269bb4
authored
Jul 28, 2025
by
Yuxuan Zhang
Committed by
GitHub
Jul 28, 2025
Browse files
Fix GLM tool parser (#21668)
Co-authored-by:
Chenhui Zhang
<
zhang.chenhui@outlook.com
>
parent
82acf218
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
113 additions
and
330 deletions
+113
-330
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
+113
-330
No files found.
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
View file @
93269bb4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# code modified from deepseekv3_tool_parser.py
import
ast
import
json
from
collections.abc
import
Sequence
from
typing
import
Union
from
typing
import
Any
,
Optional
,
Union
import
regex
as
re
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionToolsParam
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
...
...
@@ -34,36 +36,13 @@ class Glm4MoeModelToolParser(ToolParser):
self
.
tool_calls_start_token
=
self
.
tool_call_start_token
# Updated regex for the XML-based format
self
.
tool_call_regex
=
re
.
compile
(
r
"<tool_call>\s*"
r
"(?P<function_name>[^\n<]+)\s*"
# 函数名(到换行或 <)
r
"(?P<arguments>(?:\s*<arg_key>[^<]+</arg_key>\s*"
r
"<arg_value>[^<]*</arg_value>\s*)*)\s*"
r
"</tool_call>"
,
re
.
DOTALL
,
)
# Regex for parsing individual arguments
self
.
arg_regex
=
re
.
compile
(
r
"<arg_key>(?P<key>[^<]+)</arg_key>\s*<arg_value>(?P<value>[^<]*)</arg_value>"
,
re
.
DOTALL
,
)
# Streaming regex
self
.
stream_tool_call_portion_regex
=
re
.
compile
(
r
"(?P<function_name>[^\n<]+)\s*"
r
"(?P<arguments>(?:\s*<arg_key>[^<]+</arg_key>\s*"
r
"<arg_value>[^<]*</arg_value>\s*)*)"
,
re
.
DOTALL
,
)
# For streaming, we also need a regex to match just the function name
self
.
stream_tool_call_name_regex
=
re
.
compile
(
r
"(?P<function_name>[^\n<]+)"
,
re
.
DOTALL
,
)
self
.
func_call_regex
=
re
.
compile
(
r
"<tool_call>.*?</tool_call>"
,
re
.
DOTALL
)
self
.
func_detail_regex
=
re
.
compile
(
r
"<tool_call>([^\n]*)\n(.*)</tool_call>"
,
re
.
DOTALL
)
self
.
func_arg_regex
=
re
.
compile
(
r
"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
,
re
.
DOTALL
)
if
not
self
.
model_tokenizer
:
raise
ValueError
(
"The model tokenizer must be passed to the ToolParser "
...
...
@@ -72,20 +51,7 @@ class Glm4MoeModelToolParser(ToolParser):
self
.
tool_call_start_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_start_token
)
self
.
tool_call_end_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_end_token
)
def
_parse_arguments
(
self
,
args_text
:
str
)
->
str
:
"""Parse XML-based arguments into JSON format."""
if
not
args_text
or
not
args_text
.
strip
():
return
"{}"
args_dict
=
{}
matches
=
self
.
arg_regex
.
findall
(
args_text
)
for
key
,
value
in
matches
:
args_dict
[
key
.
strip
()]
=
value
.
strip
()
import
json
return
json
.
dumps
(
args_dict
,
ensure_ascii
=
False
)
self
.
_buffer
=
""
def
extract_tool_calls
(
self
,
...
...
@@ -93,52 +59,67 @@ class Glm4MoeModelToolParser(ToolParser):
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
# sanity check; avoid unnecessary processing
if
self
.
tool_calls_start_token
not
in
model_output
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
_is_string_type
(
tool_name
:
str
,
arg_name
:
str
,
tools
:
Optional
[
list
[
ChatCompletionToolsParam
]])
->
bool
:
if
tools
is
None
:
return
False
for
tool
in
tools
:
if
tool
.
function
.
name
==
tool_name
:
if
tool
.
function
.
parameters
is
None
:
return
False
arg_type
=
tool
.
function
.
parameters
.
get
(
"properties"
,
{}).
get
(
arg_name
,
{}).
get
(
"type"
,
None
)
return
arg_type
==
"string"
logger
.
warning
(
"No tool named '%s'."
,
tool_name
)
return
False
def
_deserialize
(
value
:
str
)
->
Any
:
try
:
# Find all tool calls in the output
function_call_matches
=
self
.
tool_call_regex
.
findall
(
model_output
)
logger
.
debug
(
"function_call_matches: %s"
,
function_call_matches
)
return
json
.
loads
(
value
)
except
Exception
:
pass
if
not
function_call_matches
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
,
)
try
:
return
ast
.
literal_eval
(
value
)
except
Exception
:
pass
return
value
matched_tool_calls
=
self
.
func_call_regex
.
findall
(
model_output
)
logger
.
debug
(
"model_output: %s"
,
model_output
)
try
:
tool_calls
=
[]
for
i
,
match
in
enumerate
(
function_call_matches
):
function_name
,
function_args_xml
=
match
function_name
=
function_name
.
strip
()
# Parse XML arguments to JSON
function_args_json
=
self
.
_parse_arguments
(
function_args_xml
)
for
match
in
matched_tool_calls
:
tc_detail
=
self
.
func_detail_regex
.
search
(
match
)
tc_name
=
tc_detail
.
group
(
1
)
tc_args
=
tc_detail
.
group
(
2
)
pairs
=
self
.
func_arg_regex
.
findall
(
tc_args
)
arg_dct
=
{}
for
key
,
value
in
pairs
:
arg_key
=
key
.
strip
()
arg_val
=
value
.
strip
()
if
not
_is_string_type
(
tc_name
,
arg_key
,
request
.
tools
):
arg_val
=
_deserialize
(
arg_val
)
logger
.
debug
(
"arg_key = %s, arg_val = %s"
,
arg_key
,
arg_val
)
arg_dct
[
arg_key
]
=
arg_val
tool_calls
.
append
(
ToolCall
(
id
=
f
"call_
{
i
}
"
,
type
=
'function'
,
function
=
FunctionCall
(
name
=
function_name
,
arguments
=
function_args_json
),
))
# Extract content before the first tool call
content
=
model_output
[:
model_output
.
find
(
self
.
tool_calls_start_token
)]
return
ExtractedToolCallInformation
(
tools_called
=
bool
(
tool_calls
),
tool_calls
=
tool_calls
,
content
=
content
.
strip
()
if
content
.
strip
()
else
None
,
)
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
tc_name
,
arguments
=
json
.
dumps
(
arg_dct
))))
except
Exception
:
logger
.
exception
(
"Error in extracting tool call from response."
)
logger
.
exception
(
"Failed to extract tool call spec"
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
else
:
if
len
(
tool_calls
)
>
0
:
content
=
model_output
[:
model_output
.
find
(
self
.
tool_calls_start_token
)]
return
ExtractedToolCallInformation
(
tools_called
=
True
,
tool_calls
=
tool_calls
,
content
=
content
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
...
...
@@ -153,250 +134,52 @@ class Glm4MoeModelToolParser(ToolParser):
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
Union
[
DeltaMessage
,
None
]:
logger
.
debug
(
"delta_text: %s"
,
delta_text
)
logger
.
debug
(
"delta_token_ids: %s"
,
delta_token_ids
)
# check to see if we should be streaming a tool call - is there a
if
self
.
tool_call_start_token_id
not
in
current_token_ids
:
logger
.
debug
(
"No tool call tokens found!"
)
return
DeltaMessage
(
content
=
delta_text
)
delta_text
=
delta_text
.
replace
(
self
.
tool_calls_start_token
,
""
).
replace
(
self
.
tool_call_end_token
,
""
)
try
:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count
=
previous_token_ids
.
count
(
self
.
tool_call_start_token_id
)
prev_tool_end_count
=
previous_token_ids
.
count
(
self
.
tool_call_end_token_id
)
cur_tool_start_count
=
current_token_ids
.
count
(
self
.
tool_call_start_token_id
)
cur_tool_end_count
=
current_token_ids
.
count
(
self
.
tool_call_end_token_id
)
tool_call_portion
=
None
text_portion
=
None
# case: if we're generating text, OR rounding out a tool call
if
(
cur_tool_start_count
==
cur_tool_end_count
and
prev_tool_end_count
==
cur_tool_end_count
and
self
.
tool_call_end_token
not
in
delta_text
):
logger
.
debug
(
"Generating text content! skipping tool parsing."
)
return
DeltaMessage
(
content
=
delta_text
)
if
self
.
tool_call_end_token
in
delta_text
:
logger
.
debug
(
"tool_call_end_token in delta_text"
)
full_text
=
current_text
+
delta_text
tool_call_portion
=
full_text
.
split
(
self
.
tool_call_start_token
)[
-
1
].
split
(
self
.
tool_call_end_token
)[
0
].
rstrip
()
delta_text
=
delta_text
.
split
(
self
.
tool_call_end_token
)[
0
].
rstrip
()
text_portion
=
delta_text
.
split
(
self
.
tool_call_end_token
)[
-
1
].
lstrip
()
# case -- we're starting a new tool call
if
(
cur_tool_start_count
>
cur_tool_end_count
and
cur_tool_start_count
>
prev_tool_start_count
):
if
len
(
delta_token_ids
)
>
1
:
tool_call_portion
=
current_text
.
split
(
self
.
tool_call_start_token
)[
-
1
]
else
:
tool_call_portion
=
None
delta
=
None
text_portion
=
None
# set cursors and state appropriately
self
.
current_tool_id
+=
1
self
.
current_tool_name_sent
=
False
self
.
_buffer
+=
delta_text
cur_text
=
self
.
_buffer
start_idx
=
cur_text
.
find
(
self
.
tool_call_start_token
)
if
start_idx
==
-
1
:
self
.
_buffer
=
""
if
self
.
current_tool_id
>
0
:
cur_text
=
""
return
DeltaMessage
(
content
=
cur_text
)
logger
.
debug
(
"cur_text = %s"
,
cur_text
)
end_idx
=
cur_text
.
find
(
self
.
tool_call_end_token
)
if
end_idx
!=
-
1
:
if
self
.
current_tool_id
==
-
1
:
self
.
current_tool_id
=
0
self
.
prev_tool_call_arr
=
[]
self
.
streamed_args_for_tool
=
[]
while
len
(
self
.
prev_tool_call_arr
)
<=
self
.
current_tool_id
:
self
.
prev_tool_call_arr
.
append
({})
while
len
(
self
.
streamed_args_for_tool
)
<=
self
.
current_tool_id
:
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"Starting on a new tool %s"
,
self
.
current_tool_id
)
# case -- we're updating an existing tool call
elif
(
cur_tool_start_count
>
cur_tool_end_count
and
cur_tool_start_count
==
prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion
=
current_text
.
split
(
self
.
tool_call_start_token
)[
-
1
]
text_portion
=
None
extracted_tool_calls
=
self
.
extract_tool_calls
(
cur_text
[:
end_idx
+
len
(
self
.
tool_call_end_token
)],
request
)
# case -- the current tool call is being closed.
elif
(
cur_tool_start_count
==
cur_tool_end_count
and
cur_tool_end_count
>=
prev_tool_end_count
):
if
self
.
prev_tool_call_arr
is
None
or
len
(
self
.
prev_tool_call_arr
)
==
0
:
logger
.
debug
(
"attempting to close tool call, but no tool call"
)
if
len
(
extracted_tool_calls
.
tool_calls
)
==
0
:
logger
.
warning
(
"Failed to extract any tool calls."
)
return
None
diff
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
if
diff
:
diff
=
(
diff
.
encode
(
"utf-8"
).
decode
(
"unicode_escape"
)
if
diff
is
str
else
diff
)
if
'"}'
not
in
delta_text
:
return
None
end_loc
=
delta_text
.
rindex
(
'"}'
)
diff
=
delta_text
[:
end_loc
]
+
'"}'
logger
.
debug
(
"Finishing tool and found diff that had not "
"been streamed yet: %s"
,
diff
,
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
diff
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
diff
).
model_dump
(
exclude_none
=
True
),
)
])
# case -- otherwise we're just generating text
else
:
text
=
delta_text
.
replace
(
self
.
tool_call_start_token
,
""
)
text
=
text
.
replace
(
self
.
tool_call_end_token
,
""
)
delta
=
DeltaMessage
(
tool_calls
=
[],
content
=
text
)
return
delta
current_tool_call
=
dict
()
if
tool_call_portion
:
current_tool_call_matches
=
(
self
.
stream_tool_call_portion_regex
.
match
(
tool_call_portion
))
if
current_tool_call_matches
:
tool_id
,
tool_args
=
(
current_tool_call_matches
.
groups
())
tool_name
=
tool_id
.
split
(
'.'
)[
1
].
split
(
':'
)[
0
]
current_tool_call
[
'id'
]
=
tool_id
current_tool_call
[
"name"
]
=
tool_name
current_tool_call
[
"arguments"
]
=
tool_args
else
:
current_tool_call_name_matches
=
(
self
.
stream_tool_call_name_regex
.
match
(
tool_call_portion
))
if
current_tool_call_name_matches
:
tool_id_str
,
=
current_tool_call_name_matches
.
groups
()
tool_name
=
tool_id_str
.
split
(
'.'
)[
1
].
split
(
':'
)[
0
]
current_tool_call
[
'id'
]
=
tool_id_str
current_tool_call
[
"name"
]
=
tool_name
current_tool_call
[
"arguments"
]
=
""
else
:
logger
.
debug
(
"Not enough token"
)
return
None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if
not
self
.
current_tool_name_sent
:
if
current_tool_call
is
None
:
return
None
function_name
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"name"
)
tool_id
=
current_tool_call
.
get
(
"id"
)
if
function_name
:
self
.
current_tool_name_sent
=
True
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
id
=
tool_id
,
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
),
)
])
else
:
return
None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if
tool_call_portion
is
None
:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta
=
(
DeltaMessage
(
content
=
delta_text
)
if
text_portion
is
not
None
else
None
)
return
delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger
.
debug
(
"Trying to parse current tool call with ID %s"
,
self
.
current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if
len
(
self
.
prev_tool_call_arr
)
<=
self
.
current_tool_id
:
self
.
prev_tool_call_arr
.
append
({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
logger
.
debug
(
"diffing old arguments: %s"
,
prev_arguments
)
logger
.
debug
(
"against new ones: %s"
,
cur_arguments
)
# case -- no arguments have been created yet. skip sending a delta.
if
not
cur_arguments
and
not
prev_arguments
:
logger
.
debug
(
"Skipping text %s - no arguments"
,
delta_text
)
delta
=
None
# case -- prev arguments are defined, but non are now.
# probably impossible, but not a fatal error - just keep going
elif
not
cur_arguments
and
prev_arguments
:
logger
.
error
(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta
=
None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif
cur_arguments
and
not
prev_arguments
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
cur_arguments
).
model_dump
(
exclude_none
=
True
),
)
])
tool_call
=
extracted_tool_calls
.
tool_calls
[
0
]
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
{
"name"
:
tool_call
.
function
.
name
,
"arguments"
:
json
.
loads
(
tool_call
.
function
.
arguments
)
}
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
=
cur_arguments
# last case -- we have an update to existing arguments.
elif
cur_arguments
and
prev_arguments
:
if
(
isinstance
(
delta_text
,
str
)
and
cur_arguments
!=
prev_arguments
and
len
(
cur_arguments
)
>
len
(
prev_arguments
)
and
cur_arguments
.
startswith
(
prev_arguments
)):
delta_arguments
=
cur_arguments
[
len
(
prev_arguments
):]
logger
.
debug
(
"got diff %s"
,
delta_text
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
self
.
current_tool_id
]
=
tool_call
.
function
.
arguments
delta
=
DeltaMessage
(
content
=
extracted_tool_calls
.
content
,
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
id
=
tool_call
.
id
,
type
=
tool_call
.
type
,
function
=
DeltaFunctionCall
(
arguments
=
delta_arguments
).
model_dump
(
exclude_none
=
True
),
)
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
))
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
=
cur_arguments
else
:
delta
=
None
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if
self
.
current_tool_id
==
len
(
self
.
prev_tool_call_arr
)
-
1
:
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
current_tool_call
else
:
self
.
prev_tool_call_arr
.
append
(
current_tool_call
)
self
.
current_tool_id
+=
1
self
.
_buffer
=
cur_text
[
end_idx
+
len
(
self
.
tool_call_end_token
):]
return
delta
except
Exception
:
logger
.
exception
(
"Error trying to handle streaming tool call."
)
return
None
# do not stream a delta. skip this token ID.
self
.
_buffer
=
cur_text
[
start_idx
:]
return
DeltaMessage
(
content
=
cur_text
[:
start_idx
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment