Unverified Commit cfa49213 authored by Doug Smith, committed by GitHub
Browse files

[Bugfix][Parser] Fix Mistral pre-v11 tool parser failing on trailing model output (#40531)


Signed-off-by: dougbtv <dosmith@redhat.com>
Signed-off-by: Doug Smith <dougbtv@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Flora Feng <4florafeng@gmail.com>
parent 29f64c5f
......@@ -24,7 +24,6 @@ from mistral_common.protocol.instruct.tool_calls import (
ToolChoiceEnum as MistralToolChoiceEnum,
)
from partial_json_parser.core.options import Allow
from pydantic import ValidationError
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
......@@ -250,6 +249,7 @@ def test_extract_tool_calls_no_tools(parser_fixture, request):
"argument_before_name_and_name_in_argument",
"multiple_tools",
"content_before_tool",
"trailing_data_after_json",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
......@@ -338,6 +338,24 @@ def test_extract_tool_calls_no_tools(parser_fixture, request):
],
"Hello",
),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\nextra trailing data""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
None,
),
],
)
def test_extract_tool_calls_pre_v11_tokenizer(
......@@ -366,19 +384,22 @@ def test_extract_tool_calls_pre_v11_multiple_bot_tokens_raises(
)
def test_extract_tool_calls_pre_v11_regex_fallback_raises(
def test_extract_tool_calls_pre_v11_regex_fallback(
mistral_pre_v11_tool_parser,
):
"""The regex fallback path finds valid JSON but does not re-serialize
the `arguments` dict to a string, causing a Pydantic
`ValidationError` when constructing `FunctionCall`."""
"""The regex fallback path finds valid JSON via regex when the primary
raw_decode fails on leading junk. It should re-serialize arguments
and return a valid tool call."""
model_output = (
'[TOOL_CALLS] junk [{"name": "add", "arguments":{"a": 1, "b": 2}}] trail'
)
with pytest.raises(ValidationError):
mistral_pre_v11_tool_parser.extract_tool_calls(
result = mistral_pre_v11_tool_parser.extract_tool_calls(
model_output, request=_DUMMY_REQUEST
)
assert result.tools_called
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "add"
assert result.tool_calls[0].function.arguments == json.dumps({"a": 1, "b": 2})
def test_extract_tool_calls_pre_v11_regex_fallback_fails(
......@@ -579,6 +600,7 @@ def _test_extract_tool_calls_streaming(
"argument_before_name",
"argument_before_name_and_name_in_argument",
"multiple_tools",
"trailing_data_after_json",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
......@@ -668,6 +690,24 @@ def _test_extract_tool_calls_streaming(
],
"",
),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\nextra trailing data""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"\nextra trailing data",
),
],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer(
......
......@@ -479,21 +479,28 @@ class MistralToolParser(ToolParser):
)
stringified_tool_calls = raw_tool_calls[0].strip()
try:
tool_calls = json.loads(stringified_tool_calls)
# Use raw_decode to parse the first valid JSON value,
# ignoring trailing tokens the model may emit after
# the tool call array.
tool_calls, _ = json.JSONDecoder().raw_decode(stringified_tool_calls)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
try:
raw_tool_call = self.tool_call_regex.findall(
stringified_tool_calls
)[0]
tool_calls = json.loads(raw_tool_call)
tool_calls = [
{
"name": tool_call["name"],
"arguments": json.dumps(
tool_call.get("arguments", {}),
ensure_ascii=False,
),
}
for tool_call in tool_calls
]
except (IndexError, json.JSONDecodeError):
logger.exception("Error in extracting tool call from response.")
# If raw decoding and decoding post regex rule fails, then just
# return content.
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
......@@ -504,7 +511,8 @@ class MistralToolParser(ToolParser):
{
"name": tool_call["name"],
"arguments": json.dumps(
tool_call["arguments"], ensure_ascii=False
tool_call.get("arguments", {}),
ensure_ascii=False,
),
}
for tool_call in tool_calls
......@@ -515,7 +523,7 @@ class MistralToolParser(ToolParser):
type="function",
function=FunctionCall(
name=tool_call["name"],
arguments=tool_call["arguments"],
arguments=tool_call.get("arguments", "{}"),
),
)
for tool_call in tool_calls
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment