Commit 9ea7d670 (unverified), authored by Flora Feng, committed by GitHub
Browse files

[Bugfix] Fix Qwen3 tool parser for Responses API tools (#38848)


Signed-off-by: sfeng33 <4florafeng@gmail.com>
parent 7b80cd8a
...@@ -5,6 +5,7 @@ import json ...@@ -5,6 +5,7 @@ import json
from collections.abc import Generator from collections.abc import Generator
import pytest import pytest
from openai.types.responses.function_tool import FunctionTool
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
...@@ -49,15 +50,7 @@ def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, req ...@@ -49,15 +50,7 @@ def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, req
return qwen3_xml_tool_parser return qwen3_xml_tool_parser
@pytest.fixture WEATHER_PARAMS = {
def sample_tools():
return [
ChatCompletionToolsParam(
type="function",
function={
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"city": {"type": "string", "description": "The city name"}, "city": {"type": "string", "description": "The city name"},
...@@ -65,7 +58,28 @@ def sample_tools(): ...@@ -65,7 +58,28 @@ def sample_tools():
"unit": {"type": "string", "enum": ["fahrenheit", "celsius"]}, "unit": {"type": "string", "enum": ["fahrenheit", "celsius"]},
}, },
"required": ["city", "state"], "required": ["city", "state"],
}
AREA_PARAMS = {
"type": "object",
"properties": {
"shape": {"type": "string"},
"dimensions": {"type": "object"},
"precision": {"type": "integer"},
}, },
}
@pytest.fixture(params=["chat_completion", "responses_api"])
def sample_tools(request):
if request.param == "chat_completion":
return [
ChatCompletionToolsParam(
type="function",
function={
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": WEATHER_PARAMS,
}, },
), ),
ChatCompletionToolsParam( ChatCompletionToolsParam(
...@@ -73,17 +87,25 @@ def sample_tools(): ...@@ -73,17 +87,25 @@ def sample_tools():
function={ function={
"name": "calculate_area", "name": "calculate_area",
"description": "Calculate area of a shape", "description": "Calculate area of a shape",
"parameters": { "parameters": AREA_PARAMS,
"type": "object",
"properties": {
"shape": {"type": "string"},
"dimensions": {"type": "object"},
"precision": {"type": "integer"},
},
},
}, },
), ),
] ]
else:
return [
FunctionTool(
type="function",
name="get_current_weather",
description="Get the current weather",
parameters=WEATHER_PARAMS,
),
FunctionTool(
type="function",
name="calculate_area",
description="Calculate area of a shape",
parameters=AREA_PARAMS,
),
]
def assert_tool_calls( def assert_tool_calls(
...@@ -337,12 +359,11 @@ circle ...@@ -337,12 +359,11 @@ circle
) )
def test_extract_tool_calls( def test_extract_tool_calls(
qwen3_tool_parser_parametrized, qwen3_tool_parser_parametrized,
sample_tools,
model_output, model_output,
expected_tool_calls, expected_tool_calls,
expected_content, expected_content,
): ):
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request model_output, request=request
) )
...@@ -354,7 +375,7 @@ def test_extract_tool_calls( ...@@ -354,7 +375,7 @@ def test_extract_tool_calls(
def test_extract_tool_calls_fallback_no_tags( def test_extract_tool_calls_fallback_no_tags(
qwen3_tool_parser_parametrized, sample_tools qwen3_tool_parser_parametrized,
): ):
"""Test fallback parsing when XML tags are missing""" """Test fallback parsing when XML tags are missing"""
model_output = """<function=get_current_weather> model_output = """<function=get_current_weather>
...@@ -366,7 +387,7 @@ TX ...@@ -366,7 +387,7 @@ TX
</parameter> </parameter>
</function>""" </function>"""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request model_output, request=request
) )
...@@ -607,13 +628,12 @@ circle ...@@ -607,13 +628,12 @@ circle
def test_extract_tool_calls_streaming( def test_extract_tool_calls_streaming(
qwen3_tool_parser_parametrized, qwen3_tool_parser_parametrized,
qwen3_tokenizer, qwen3_tokenizer,
sample_tools,
model_output, model_output,
expected_tool_calls, expected_tool_calls,
expected_content, expected_content,
): ):
"""Test incremental streaming behavior including typed parameters""" """Test incremental streaming behavior including typed parameters"""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
other_content = "" other_content = ""
tool_states = {} # Track state per tool index tool_states = {} # Track state per tool index
...@@ -683,7 +703,7 @@ def test_extract_tool_calls_streaming( ...@@ -683,7 +703,7 @@ def test_extract_tool_calls_streaming(
def test_extract_tool_calls_missing_closing_parameter_tag( def test_extract_tool_calls_missing_closing_parameter_tag(
qwen3_tool_parser_parametrized, sample_tools qwen3_tool_parser_parametrized,
): ):
"""Test handling of missing closing </parameter> tag""" """Test handling of missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML # Using get_current_weather from sample_tools but with malformed XML
...@@ -701,7 +721,7 @@ fahrenheit ...@@ -701,7 +721,7 @@ fahrenheit
</function> </function>
</tool_call>""" </tool_call>"""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request model_output, request=request
) )
...@@ -725,7 +745,7 @@ fahrenheit ...@@ -725,7 +745,7 @@ fahrenheit
def test_extract_tool_calls_streaming_missing_closing_tag( def test_extract_tool_calls_streaming_missing_closing_tag(
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools qwen3_tool_parser_parametrized, qwen3_tokenizer
): ):
"""Test streaming with missing closing </parameter> tag""" """Test streaming with missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML # Using get_current_weather from sample_tools but with malformed XML
...@@ -743,7 +763,7 @@ fahrenheit ...@@ -743,7 +763,7 @@ fahrenheit
</function> </function>
</tool_call>""" </tool_call>"""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
other_content = "" other_content = ""
tool_states = {} tool_states = {}
...@@ -800,7 +820,7 @@ fahrenheit ...@@ -800,7 +820,7 @@ fahrenheit
def test_extract_tool_calls_streaming_incremental( def test_extract_tool_calls_streaming_incremental(
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools qwen3_tool_parser_parametrized, qwen3_tokenizer
): ):
"""Test that streaming is truly incremental""" """Test that streaming is truly incremental"""
model_output = """I'll check the weather.<tool_call> model_output = """I'll check the weather.<tool_call>
...@@ -814,7 +834,7 @@ TX ...@@ -814,7 +834,7 @@ TX
</function> </function>
</tool_call>""" </tool_call>"""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
chunks = [] chunks = []
for delta_message in stream_delta_message_generator( for delta_message in stream_delta_message_generator(
...@@ -897,7 +917,7 @@ def test_extract_tool_calls_complex_type_with_single_quote( ...@@ -897,7 +917,7 @@ def test_extract_tool_calls_complex_type_with_single_quote(
def test_extract_tool_calls_streaming_missing_opening_tag( def test_extract_tool_calls_streaming_missing_opening_tag(
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools qwen3_tool_parser_parametrized, qwen3_tokenizer
): ):
"""Test streaming with missing opening <tool_call> tag """Test streaming with missing opening <tool_call> tag
...@@ -919,7 +939,7 @@ fahrenheit ...@@ -919,7 +939,7 @@ fahrenheit
</function> </function>
</tool_call>""" </tool_call>"""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
other_content = "" other_content = ""
tool_states = {} tool_states = {}
...@@ -976,7 +996,7 @@ fahrenheit ...@@ -976,7 +996,7 @@ fahrenheit
assert args["unit"] == "fahrenheit" assert args["unit"] == "fahrenheit"
def test_malformed_xml_no_gt_delimiter(qwen3_tool_parser, sample_tools): def test_malformed_xml_no_gt_delimiter(qwen3_tool_parser):
"""Regression: malformed XML without '>' must not crash (PR #36774).""" """Regression: malformed XML without '>' must not crash (PR #36774)."""
model_output = ( model_output = (
"<tool_call>\n" "<tool_call>\n"
...@@ -986,14 +1006,14 @@ def test_malformed_xml_no_gt_delimiter(qwen3_tool_parser, sample_tools): ...@@ -986,14 +1006,14 @@ def test_malformed_xml_no_gt_delimiter(qwen3_tool_parser, sample_tools):
"</tool_call>" "</tool_call>"
) )
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
result = qwen3_tool_parser.extract_tool_calls(model_output, request=request) result = qwen3_tool_parser.extract_tool_calls(model_output, request=request)
assert result is not None assert result is not None
assert isinstance(result.tool_calls, list) assert isinstance(result.tool_calls, list)
assert all(tc is not None for tc in result.tool_calls) assert all(tc is not None for tc in result.tool_calls)
def test_none_tool_calls_filtered(qwen3_tool_parser, sample_tools): def test_none_tool_calls_filtered(qwen3_tool_parser):
"""Regression: None tool calls filtered from output (PR #36774).""" """Regression: None tool calls filtered from output (PR #36774)."""
model_output = ( model_output = (
"<tool_call>\n" "<tool_call>\n"
...@@ -1008,7 +1028,7 @@ def test_none_tool_calls_filtered(qwen3_tool_parser, sample_tools): ...@@ -1008,7 +1028,7 @@ def test_none_tool_calls_filtered(qwen3_tool_parser, sample_tools):
"</tool_call>" "</tool_call>"
) )
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
result = qwen3_tool_parser.extract_tool_calls(model_output, request=request) result = qwen3_tool_parser.extract_tool_calls(model_output, request=request)
assert all(tc is not None for tc in result.tool_calls) assert all(tc is not None for tc in result.tool_calls)
assert result.tools_called assert result.tools_called
...@@ -1058,11 +1078,9 @@ def test_anyof_parameter_not_double_encoded(qwen3_tokenizer): ...@@ -1058,11 +1078,9 @@ def test_anyof_parameter_not_double_encoded(qwen3_tokenizer):
assert args["data"] == {"key": "value", "count": 42} assert args["data"] == {"key": "value", "count": 42}
def test_streaming_multi_param_single_chunk( def test_streaming_multi_param_single_chunk(qwen3_tool_parser, qwen3_tokenizer):
qwen3_tool_parser, qwen3_tokenizer, sample_tools
):
"""Regression: speculative decode delivering multiple params at once (PR #35615).""" """Regression: speculative decode delivering multiple params at once (PR #35615)."""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[])
deltas = [ deltas = [
"<tool_call>", "<tool_call>",
......
...@@ -25,6 +25,7 @@ from vllm.tool_parsers.abstract_tool_parser import ( ...@@ -25,6 +25,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
Tool, Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import find_tool_properties
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -109,28 +110,6 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -109,28 +110,6 @@ class Qwen3CoderToolParser(ToolParser):
self.accumulated_params = {} self.accumulated_params = {}
self.streaming_request = None self.streaming_request = None
def _get_arguments_config(self, func_name: str, tools: list[Tool] | None) -> dict:
    """Return the argument schema declared for ``func_name``.

    Only entries that expose ``type``, ``function`` and ``function.name``
    are inspected; anything else in ``tools`` is skipped. For the first
    entry whose ``type`` is ``"function"`` and whose name matches, the
    result is the ``"properties"`` value of its ``parameters`` dict when
    that key exists, otherwise the ``parameters`` dict itself. An empty
    dict is returned when ``tools`` is ``None``, when the matching entry
    has no dict-valued ``parameters``, or when nothing matches (the last
    case also emits a debug log line).
    """
    if tools is None:
        return {}

    for tool in tools:
        well_formed = (
            hasattr(tool, "type")
            and hasattr(tool, "function")
            and hasattr(tool.function, "name")
        )
        if not well_formed:
            continue
        if not (tool.type == "function" and tool.function.name == func_name):
            continue

        # A missing ``parameters`` attribute and a non-dict value both
        # resolve to the same empty result.
        schema = getattr(tool.function, "parameters", None)
        if not isinstance(schema, dict):
            return {}
        # Prefer the nested JSON-schema ``properties``; otherwise treat the
        # dict as a flat name -> config mapping.
        return schema["properties"] if "properties" in schema else schema

    logger.debug("Tool '%s' is not defined in the tools list.", func_name)
    return {}
def _convert_param_value( def _convert_param_value(
self, param_value: str, param_name: str, param_config: dict, func_name: str self, param_value: str, param_name: str, param_config: dict, func_name: str
) -> Any: ) -> Any:
...@@ -243,16 +222,14 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -243,16 +222,14 @@ class Qwen3CoderToolParser(ToolParser):
) )
return param_value return param_value
def _parse_xml_function_call( def _parse_xml_function_call(self, function_call_str: str) -> ToolCall | None:
self, function_call_str: str, tools: list[Tool] | None
) -> ToolCall | None:
# Extract function name # Extract function name
end_index = function_call_str.find(">") end_index = function_call_str.find(">")
# If there's no ">" character, this is not a valid xml function call # If there's no ">" character, this is not a valid xml function call
if end_index == -1: if end_index == -1:
return None return None
function_name = function_call_str[:end_index] function_name = function_call_str[:end_index]
param_config = self._get_arguments_config(function_name, tools) param_config = find_tool_properties(self.tools, function_name)
parameters = function_call_str[end_index + 1 :] parameters = function_call_str[end_index + 1 :]
param_dict = {} param_dict = {}
for match_text in self.tool_call_parameter_regex.findall(parameters): for match_text in self.tool_call_parameter_regex.findall(parameters):
...@@ -314,7 +291,7 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -314,7 +291,7 @@ class Qwen3CoderToolParser(ToolParser):
) )
tool_calls = [ tool_calls = [
self._parse_xml_function_call(function_call_str, self.tools) self._parse_xml_function_call(function_call_str)
for function_call_str in function_calls for function_call_str in function_calls
] ]
# Populate prev_tool_call_arr for serving layer to set finish_reason # Populate prev_tool_call_arr for serving layer to set finish_reason
...@@ -605,9 +582,8 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -605,9 +582,8 @@ class Qwen3CoderToolParser(ToolParser):
self.current_param_name = current_param_name self.current_param_name = current_param_name
self.accumulated_params[current_param_name] = param_value self.accumulated_params[current_param_name] = param_value
param_config = self._get_arguments_config( param_config = find_tool_properties(
self.current_function_name or "", self.tools, self.current_function_name or ""
self.tools,
) )
converted_value = self._convert_param_value( converted_value = self._convert_param_value(
...@@ -666,7 +642,6 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -666,7 +642,6 @@ class Qwen3CoderToolParser(ToolParser):
try: try:
parsed_tool = self._parse_xml_function_call( parsed_tool = self._parse_xml_function_call(
func_content, func_content,
self.tools,
) )
if parsed_tool and self.current_tool_index < len( if parsed_tool and self.current_tool_index < len(
self.prev_tool_call_arr self.prev_tool_call_arr
......
...@@ -26,6 +26,7 @@ from vllm.tool_parsers.abstract_tool_parser import ( ...@@ -26,6 +26,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
Tool, Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import find_tool_properties
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1000,33 +1001,11 @@ class StreamingXMLToolCallParser: ...@@ -1000,33 +1001,11 @@ class StreamingXMLToolCallParser:
if not self.tools or not self.current_function_name: if not self.tools or not self.current_function_name:
return "string" return "string"
for tool in self.tools: properties = find_tool_properties(self.tools, self.current_function_name)
if not hasattr(tool, "type") or not ( if param_name in properties and isinstance(properties[param_name], dict):
hasattr(tool, "function") and hasattr(tool.function, "name")
):
continue
if (
tool.type == "function"
and tool.function.name == self.current_function_name
):
if not hasattr(tool.function, "parameters"):
return "string"
params = tool.function.parameters
if isinstance(params, dict) and "properties" in params:
properties = params["properties"]
if param_name in properties and isinstance(
properties[param_name], dict
):
return self.repair_param_type( return self.repair_param_type(
str(properties[param_name].get("type", "string")) str(properties[param_name].get("type", "string"))
) )
elif isinstance(params, dict) and param_name in params:
param_config = params[param_name]
if isinstance(param_config, dict):
return self.repair_param_type(
str(param_config.get("type", "string"))
)
break
return "string" return "string"
def repair_param_type(self, param_type: str) -> str: def repair_param_type(self, param_type: str) -> str:
......
...@@ -142,6 +142,20 @@ def _extract_tool_info( ...@@ -142,6 +142,20 @@ def _extract_tool_info(
raise TypeError(f"Unsupported tool type: {type(tool)}") raise TypeError(f"Unsupported tool type: {type(tool)}")
def find_tool_properties(
    tools: list[Tool] | None,
    tool_name: str,
) -> dict[str, Any]:
    """Find a tool by name and return its ``properties`` dict, or ``{}``.

    Args:
        tools: Tool definitions to search. ``None`` or an empty list
            yields ``{}``.
        tool_name: Name compared against the name that
            ``_extract_tool_info`` reports for each tool.

    Returns:
        The ``"properties"`` entry of the first matching tool's parameters.
        ``{}`` is returned when no tool matches, when the matching tool has
        no dict-valued parameters, or when ``"properties"`` is missing or
        is not a dict.

    Raises:
        Any exception raised by ``_extract_tool_info`` for a tool it
        cannot handle is propagated unchanged.
    """
    if not tools:
        return {}
    for tool in tools:
        name, params = _extract_tool_info(tool)
        if name != tool_name:
            continue
        properties = params.get("properties") if isinstance(params, dict) else None
        # Callers run membership tests on the result (``name in properties``),
        # so a null or otherwise non-dict ``properties`` value must be
        # normalized to an empty dict rather than returned as-is.
        return properties if isinstance(properties, dict) else {}
    return {}
def _get_tool_schema_from_tool(tool: Tool) -> dict: def _get_tool_schema_from_tool(tool: Tool) -> dict:
name, params = _extract_tool_info(tool) name, params = _extract_tool_info(tool)
params = params if params else {"type": "object", "properties": {}} params = params if params else {"type": "object", "properties": {}}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment