Unverified Commit 38c361f9 authored by Joachim Studnia's avatar Joachim Studnia Committed by GitHub
Browse files
parent bb62dda2
...@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer( ...@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
"single_tool_add", "single_tool_add",
"single_tool_weather", "single_tool_weather",
"multiple_tool_calls", "multiple_tool_calls",
"complex",
"wrong_json",
], ],
argnames=["model_output", "expected_tool_calls", "expected_content"], argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[ argvalues=[
...@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer( ...@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
], ],
None, None,
), ),
(
# Complex
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
)[:-2],
)
)
],
"hi{hi",
),
(
# Wrong json
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"hi{hi",
),
], ],
) )
def test_extract_tool_calls( def test_extract_tool_calls(
...@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming( ...@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
), ),
( (
# Complex # Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 """hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[ [
ToolCall( ToolCall(
function=FunctionCall( function=FunctionCall(
...@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming( ...@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
) )
) )
], ],
"", "hi{hi",
), ),
], ],
) )
......
...@@ -131,78 +131,105 @@ class MistralToolParser(ToolParser): ...@@ -131,78 +131,105 @@ class MistralToolParser(ToolParser):
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> ExtractedToolCallInformation: ) -> ExtractedToolCallInformation:
""" """
Extract the tool calls from a complete model response. Requires Extract the tool calls from a complete model response.
find-and-replacing single quotes with double quotes for JSON parsing,
make sure your tool call arguments don't ever include quotes! Content and tool calls formatting depends on the Mistral's tokenizer version
used to train the model:
- < v11: `content[BOT] [{tool_call1},{tool_call2}]`
- >= v11: `content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}`
with [BOT] the tool call token.
Note:
For tokenizer versions >= v11, tool calls with arguments wrongly formatted
are still returned as tool calls. This is to allow the model to know it
tried to make a tool call. It reduces chance of another failure and
prevents that the context is filled with tool calls wrongly placed in
assistant message contents.
""" """
# case -- if a tool call token is not present, return a text response # If the tool call token is not present, return a text response
if self.bot_token not in model_output: if self.bot_token not in model_output:
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output tools_called=False, tool_calls=[], content=model_output
) )
# first remove the BOT token content_and_raw_tool_calls = model_output.split(self.bot_token)
tool_content = model_output.replace(self.bot_token, "").strip() content = content_and_raw_tool_calls[0]
raw_tool_calls = content_and_raw_tool_calls[1:]
# >= v11: content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}
if not self._is_pre_v11:
tool_calls = []
for raw_tool_call in raw_tool_calls:
if "{" not in raw_tool_call:
continue
end_name = raw_tool_call.find("{")
tool_name, args = (
raw_tool_call[:end_name],
raw_tool_call[end_name:],
)
try: tool_calls.append({"name": tool_name, "arguments": args})
# < v11: content[BOT] [{tool_call1},{tool_call2}]
else:
if len(raw_tool_calls) != 1:
raise ValueError(
"Only one BOT token should have been outputted, "
f"but got {model_output}."
)
stringified_tool_calls = raw_tool_calls[0].strip()
try: try:
if not self._is_pre_v11: tool_calls = json.loads(stringified_tool_calls)
function_call_arr = []
for single_tool_content in model_output.split(self.bot_token):
if "{" not in single_tool_content:
continue
end_name = single_tool_content.find("{")
fn_name, args = (
single_tool_content[:end_name],
single_tool_content[end_name:],
)
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
else:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError: except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call. # use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained # NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but # correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls # can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0] try:
function_call_arr = json.loads(raw_tool_call) raw_tool_call = self.tool_call_regex.findall(
stringified_tool_calls
# Tool Call )[0]
tool_calls: list[MistralToolCall] = [ tool_calls = json.loads(raw_tool_call)
MistralToolCall( except (IndexError, json.JSONDecodeError):
type="function", logger.exception("Error in extracting tool call from response: {e}")
function=FunctionCall( # If raw decoding and decoding post regex rule fails, then just
name=raw_function_call["name"], # return content.
# function call args are JSON but as a string return ExtractedToolCallInformation(
arguments=json.dumps( tools_called=False,
raw_function_call["arguments"], ensure_ascii=False tool_calls=[],
content=stringified_tool_calls,
)
else:
tool_calls = [
{
"name": tool_call["name"],
"arguments": json.dumps(
tool_call["arguments"], ensure_ascii=False
), ),
), }
) for tool_call in tool_calls
for raw_function_call in function_call_arr ]
]
# get any content before the tool call mistral_tool_calls: list[MistralToolCall] = [
content = model_output.split(self.bot_token)[0] MistralToolCall(
return ExtractedToolCallInformation( type="function",
tools_called=True, function=FunctionCall(
tool_calls=tool_calls, name=tool_call["name"],
content=content if len(content) > 0 else None, arguments=tool_call["arguments"],
),
) )
for tool_call in tool_calls
]
except Exception: return ExtractedToolCallInformation(
logger.exception("Error in extracting tool call from response.") tools_called=True,
# return information to just treat the tool call as regular JSON tool_calls=mistral_tool_calls,
return ExtractedToolCallInformation( content=content if len(content) > 0 else None,
tools_called=False, tool_calls=[], content=tool_content )
)
def extract_tool_calls_streaming( def extract_tool_calls_streaming(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment