Unverified commit 5c213d28, authored by Julien Denize and committed by GitHub
Browse files

[BUGFIX] Mistral tool call parser v11+ (#30332)


Signed-off-by: juliendenize <julien.denize@mistral.ai>
parent ee14644b
...@@ -615,6 +615,7 @@ def test_extract_tool_calls_streaming( ...@@ -615,6 +615,7 @@ def test_extract_tool_calls_streaming(
"single_tool_weather", "single_tool_weather",
"multiple_tool_calls", "multiple_tool_calls",
"content_before_tool", "content_before_tool",
"complex",
], ],
argnames=["model_output", "expected_tool_calls", "expected_content"], argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[ argvalues=[
...@@ -673,6 +674,21 @@ def test_extract_tool_calls_streaming( ...@@ -673,6 +674,21 @@ def test_extract_tool_calls_streaming(
], ],
"bla", "bla",
), ),
(
# Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"",
),
], ],
) )
def test_extract_tool_calls_streaming_one_chunk( def test_extract_tool_calls_streaming_one_chunk(
......
...@@ -99,12 +99,7 @@ class MistralToolParser(ToolParser): ...@@ -99,12 +99,7 @@ class MistralToolParser(ToolParser):
self.bot_token = "[TOOL_CALLS]" self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token) self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if not _is_pre_v11_tokeniser(self.model_tokenizer): self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)
self.fn_name_regex = re.compile(
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
)
else:
self.fn_name_regex = None
if self.bot_token_id is None: if self.bot_token_id is None:
raise RuntimeError( raise RuntimeError(
...@@ -148,17 +143,18 @@ class MistralToolParser(ToolParser): ...@@ -148,17 +143,18 @@ class MistralToolParser(ToolParser):
tool_content = model_output.replace(self.bot_token, "").strip() tool_content = model_output.replace(self.bot_token, "").strip()
try: try:
# we first try to directly load the json as parsing very nested
# jsons is difficult
try: try:
if self.fn_name_regex: if not self._is_pre_v11:
function_call_arr = [] function_call_arr = []
for single_tool_content in model_output.split(self.bot_token): for single_tool_content in model_output.split(self.bot_token):
matches = self.fn_name_regex.findall(single_tool_content) if "{" not in single_tool_content:
continue
for match in matches: end_name = single_tool_content.find("{")
fn_name = match[0] fn_name, args = (
args = match[1] single_tool_content[:end_name],
single_tool_content[end_name:],
)
# fn_name is encoded outside serialized json dump # fn_name is encoded outside serialized json dump
# only arguments are serialized # only arguments are serialized
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment