Unverified Commit cfa49213 authored by Doug Smith's avatar Doug Smith Committed by GitHub
Browse files

[Bugfix][Parser] Fix Mistral pre-v11 tool parser failing on trailing model output (#40531)


Signed-off-by: default avatardougbtv <dosmith@redhat.com>
Signed-off-by: default avatarDoug Smith <dougbtv@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: default avatarFlora Feng <4florafeng@gmail.com>
parent 29f64c5f
...@@ -24,7 +24,6 @@ from mistral_common.protocol.instruct.tool_calls import ( ...@@ -24,7 +24,6 @@ from mistral_common.protocol.instruct.tool_calls import (
ToolChoiceEnum as MistralToolChoiceEnum, ToolChoiceEnum as MistralToolChoiceEnum,
) )
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from pydantic import ValidationError
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
...@@ -250,6 +249,7 @@ def test_extract_tool_calls_no_tools(parser_fixture, request): ...@@ -250,6 +249,7 @@ def test_extract_tool_calls_no_tools(parser_fixture, request):
"argument_before_name_and_name_in_argument", "argument_before_name_and_name_in_argument",
"multiple_tools", "multiple_tools",
"content_before_tool", "content_before_tool",
"trailing_data_after_json",
], ],
argnames=["model_output", "expected_tool_calls", "expected_content"], argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[ argvalues=[
...@@ -338,6 +338,24 @@ def test_extract_tool_calls_no_tools(parser_fixture, request): ...@@ -338,6 +338,24 @@ def test_extract_tool_calls_no_tools(parser_fixture, request):
], ],
"Hello", "Hello",
), ),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\nextra trailing data""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
None,
),
], ],
) )
def test_extract_tool_calls_pre_v11_tokenizer( def test_extract_tool_calls_pre_v11_tokenizer(
...@@ -366,19 +384,22 @@ def test_extract_tool_calls_pre_v11_multiple_bot_tokens_raises( ...@@ -366,19 +384,22 @@ def test_extract_tool_calls_pre_v11_multiple_bot_tokens_raises(
) )
def test_extract_tool_calls_pre_v11_regex_fallback_raises( def test_extract_tool_calls_pre_v11_regex_fallback(
mistral_pre_v11_tool_parser, mistral_pre_v11_tool_parser,
): ):
"""The regex fallback path finds valid JSON but does not re-serialize """The regex fallback path finds valid JSON via regex when the primary
the `arguments` dict to a string, causing a Pydantic raw_decode fails on leading junk. It should re-serialize arguments
`ValidationError` when constructing `FunctionCall`.""" and return a valid tool call."""
model_output = ( model_output = (
'[TOOL_CALLS] junk [{"name": "add", "arguments":{"a": 1, "b": 2}}] trail' '[TOOL_CALLS] junk [{"name": "add", "arguments":{"a": 1, "b": 2}}] trail'
) )
with pytest.raises(ValidationError): result = mistral_pre_v11_tool_parser.extract_tool_calls(
mistral_pre_v11_tool_parser.extract_tool_calls( model_output, request=_DUMMY_REQUEST
model_output, request=_DUMMY_REQUEST )
) assert result.tools_called
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "add"
assert result.tool_calls[0].function.arguments == json.dumps({"a": 1, "b": 2})
def test_extract_tool_calls_pre_v11_regex_fallback_fails( def test_extract_tool_calls_pre_v11_regex_fallback_fails(
...@@ -579,6 +600,7 @@ def _test_extract_tool_calls_streaming( ...@@ -579,6 +600,7 @@ def _test_extract_tool_calls_streaming(
"argument_before_name", "argument_before_name",
"argument_before_name_and_name_in_argument", "argument_before_name_and_name_in_argument",
"multiple_tools", "multiple_tools",
"trailing_data_after_json",
], ],
argnames=["model_output", "expected_tool_calls", "expected_content"], argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[ argvalues=[
...@@ -668,6 +690,24 @@ def _test_extract_tool_calls_streaming( ...@@ -668,6 +690,24 @@ def _test_extract_tool_calls_streaming(
], ],
"", "",
), ),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\nextra trailing data""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"\nextra trailing data",
),
], ],
) )
def test_extract_tool_calls_streaming_pre_v11_tokenizer( def test_extract_tool_calls_streaming_pre_v11_tokenizer(
......
...@@ -479,21 +479,28 @@ class MistralToolParser(ToolParser): ...@@ -479,21 +479,28 @@ class MistralToolParser(ToolParser):
) )
stringified_tool_calls = raw_tool_calls[0].strip() stringified_tool_calls = raw_tool_calls[0].strip()
try: try:
tool_calls = json.loads(stringified_tool_calls) # Use raw_decode to parse the first valid JSON value,
# ignoring trailing tokens the model may emit after
# the tool call array.
tool_calls, _ = json.JSONDecoder().raw_decode(stringified_tool_calls)
except json.JSONDecodeError: except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
try: try:
raw_tool_call = self.tool_call_regex.findall( raw_tool_call = self.tool_call_regex.findall(
stringified_tool_calls stringified_tool_calls
)[0] )[0]
tool_calls = json.loads(raw_tool_call) tool_calls = json.loads(raw_tool_call)
tool_calls = [
{
"name": tool_call["name"],
"arguments": json.dumps(
tool_call.get("arguments", {}),
ensure_ascii=False,
),
}
for tool_call in tool_calls
]
except (IndexError, json.JSONDecodeError): except (IndexError, json.JSONDecodeError):
logger.exception("Error in extracting tool call from response.") logger.exception("Error in extracting tool call from response.")
# If raw decoding and decoding post regex rule fails, then just
# return content.
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=False, tools_called=False,
tool_calls=[], tool_calls=[],
...@@ -504,7 +511,8 @@ class MistralToolParser(ToolParser): ...@@ -504,7 +511,8 @@ class MistralToolParser(ToolParser):
{ {
"name": tool_call["name"], "name": tool_call["name"],
"arguments": json.dumps( "arguments": json.dumps(
tool_call["arguments"], ensure_ascii=False tool_call.get("arguments", {}),
ensure_ascii=False,
), ),
} }
for tool_call in tool_calls for tool_call in tool_calls
...@@ -515,7 +523,7 @@ class MistralToolParser(ToolParser): ...@@ -515,7 +523,7 @@ class MistralToolParser(ToolParser):
type="function", type="function",
function=FunctionCall( function=FunctionCall(
name=tool_call["name"], name=tool_call["name"],
arguments=tool_call["arguments"], arguments=tool_call.get("arguments", "{}"),
), ),
) )
for tool_call in tool_calls for tool_call in tool_calls
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment