Unverified commit 38c361f9, authored by Joachim Studnia, committed by GitHub
Browse files
parent bb62dda2
...@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer( ...@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
"single_tool_add", "single_tool_add",
"single_tool_weather", "single_tool_weather",
"multiple_tool_calls", "multiple_tool_calls",
"complex",
"wrong_json",
], ],
argnames=["model_output", "expected_tool_calls", "expected_content"], argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[ argvalues=[
...@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer( ...@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
], ],
None, None,
), ),
(
# Complex
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
)[:-2],
)
)
],
"hi{hi",
),
(
# Wrong json
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"hi{hi",
),
], ],
) )
def test_extract_tool_calls( def test_extract_tool_calls(
...@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming( ...@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
), ),
( (
# Complex # Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 """hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[ [
ToolCall( ToolCall(
function=FunctionCall( function=FunctionCall(
...@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming( ...@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
) )
) )
], ],
"", "hi{hi",
), ),
], ],
) )
......
...@@ -131,79 +131,106 @@ class MistralToolParser(ToolParser): ...@@ -131,79 +131,106 @@ class MistralToolParser(ToolParser):
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> ExtractedToolCallInformation: ) -> ExtractedToolCallInformation:
""" """
Extract the tool calls from a complete model response. Requires Extract the tool calls from a complete model response.
find-and-replacing single quotes with double quotes for JSON parsing,
make sure your tool call arguments don't ever include quotes! The formatting of content and tool calls depends on the Mistral tokenizer version
used to train the model:
- < v11: `content[BOT] [{tool_call1},{tool_call2}]`
- >= v11: `content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}`
with [BOT] the tool call token.
Note:
For tokenizer versions >= v11, tool calls with wrongly formatted arguments
are still returned as tool calls. This lets the model know that it
tried to make a tool call, which reduces the chance of another failure and
prevents the context from being filled with tool calls wrongly placed in
assistant message contents.
""" """
# case -- if a tool call token is not present, return a text response # If the tool call token is not present, return a text response
if self.bot_token not in model_output: if self.bot_token not in model_output:
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output tools_called=False, tool_calls=[], content=model_output
) )
# first remove the BOT token content_and_raw_tool_calls = model_output.split(self.bot_token)
tool_content = model_output.replace(self.bot_token, "").strip() content = content_and_raw_tool_calls[0]
raw_tool_calls = content_and_raw_tool_calls[1:]
try: # >= v11: content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}
try:
if not self._is_pre_v11: if not self._is_pre_v11:
function_call_arr = [] tool_calls = []
for single_tool_content in model_output.split(self.bot_token): for raw_tool_call in raw_tool_calls:
if "{" not in single_tool_content: if "{" not in raw_tool_call:
continue continue
end_name = single_tool_content.find("{") end_name = raw_tool_call.find("{")
fn_name, args = ( tool_name, args = (
single_tool_content[:end_name], raw_tool_call[:end_name],
single_tool_content[end_name:], raw_tool_call[end_name:],
) )
# fn_name is encoded outside serialized json dump tool_calls.append({"name": tool_name, "arguments": args})
# only arguments are serialized
function_call_arr.append( # < v11: content[BOT] [{tool_call1},{tool_call2}]
{"name": fn_name, "arguments": json.loads(args)}
)
else: else:
function_call_arr = json.loads(tool_content) if len(raw_tool_calls) != 1:
raise ValueError(
"Only one BOT token should have been outputted, "
f"but got {model_output}."
)
stringified_tool_calls = raw_tool_calls[0].strip()
try:
tool_calls = json.loads(stringified_tool_calls)
except json.JSONDecodeError: except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call. # use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained # NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but # correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls # can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0] try:
function_call_arr = json.loads(raw_tool_call) raw_tool_call = self.tool_call_regex.findall(
stringified_tool_calls
)[0]
tool_calls = json.loads(raw_tool_call)
except (IndexError, json.JSONDecodeError):
logger.exception("Error in extracting tool call from response: {e}")
# If both the raw decoding and the regex-based decoding fail, then just
# return the content.
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=stringified_tool_calls,
)
else:
tool_calls = [
{
"name": tool_call["name"],
"arguments": json.dumps(
tool_call["arguments"], ensure_ascii=False
),
}
for tool_call in tool_calls
]
# Tool Call mistral_tool_calls: list[MistralToolCall] = [
tool_calls: list[MistralToolCall] = [
MistralToolCall( MistralToolCall(
type="function", type="function",
function=FunctionCall( function=FunctionCall(
name=raw_function_call["name"], name=tool_call["name"],
# function call args are JSON but as a string arguments=tool_call["arguments"],
arguments=json.dumps(
raw_function_call["arguments"], ensure_ascii=False
),
), ),
) )
for raw_function_call in function_call_arr for tool_call in tool_calls
] ]
# get any content before the tool call
content = model_output.split(self.bot_token)[0]
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=True, tools_called=True,
tool_calls=tool_calls, tool_calls=mistral_tool_calls,
content=content if len(content) > 0 else None, content=content if len(content) > 0 else None,
) )
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=tool_content
)
def extract_tool_calls_streaming( def extract_tool_calls_streaming(
self, self,
previous_text: str, previous_text: str,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment