Unverified commit 2ec18f5d, authored by Thomas, committed by GitHub
Browse files

[Bugfix][Parser] Fix Mistral tool parser for HF tokenizers (#39294)


Signed-off-by: thomasmaindron <thomasmaindron@users.noreply.github.com>
Co-authored-by: thomasmaindron <thomasmaindron@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
parent 6dec49f2
......@@ -91,7 +91,12 @@ class MistralToolCall(ToolCall):
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
    """Return True when ``model_tokenizer`` should be treated as pre-v11.

    Mistral tokenizers are classified by their ``version`` attribute
    (anything below 11 is pre-v11). Any other tokenizer is classified by
    its vocabulary: the absence of the ``[ARGS]`` token means pre-v11.
    A tokenizer that exposes no ``get_vocab`` method is given an empty
    vocabulary and is therefore reported as pre-v11.

    Args:
        model_tokenizer: The tokenizer to inspect.

    Returns:
        ``True`` for a pre-v11 tokenizer, ``False`` otherwise.
    """
    # NOTE: the previous unconditional
    # ``return not (is_mistral_tokenizer(...) and version >= 11)`` is
    # intentionally gone; leaving it in front of the checks below makes them
    # unreachable and classifies every non-Mistral tokenizer as pre-v11.
    if is_mistral_tokenizer(model_tokenizer):
        return model_tokenizer.version < 11
    # For HF tokenizers, check if [ARGS] token exists in vocab
    # which indicates a v11+ equivalent tokenizer
    vocab: dict[str, int] = getattr(model_tokenizer, "get_vocab", lambda: {})()
    return "[ARGS]" not in vocab
@dataclass
......@@ -139,7 +144,8 @@ class MistralToolParser(ToolParser):
self.current_tool_name: str | None = None
self.current_tool_mistral_id: str | None = None
self.starting_new_tool = False
if _is_pre_v11_tokeniser(self.model_tokenizer):
self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)
if self._is_pre_v11:
self.parse_coro = ijson.parse_coro(
self.update_stream_state_pre_v11_tokenizer()
)
......@@ -147,7 +153,6 @@ class MistralToolParser(ToolParser):
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)
if self.bot_token_id is None:
raise RuntimeError(
......@@ -470,6 +475,8 @@ class MistralToolParser(ToolParser):
raw_tool_call[end_name:],
)
# HF tokenizers may include [ARGS] in the text
tool_name = tool_name.replace("[ARGS]", "")
tool_calls.append({"name": tool_name, "arguments": args})
# < v11: content[BOT] [{tool_call1},{tool_call2}]
......@@ -558,7 +565,7 @@ class MistralToolParser(ToolParser):
# if the tool call token IS in the tokens generated so far, that
# means we're parsing as tool calls now
try:
if _is_pre_v11_tokeniser(self.model_tokenizer):
if self._is_pre_v11:
return self._extract_tool_calls_streaming_pre_v11_tokenizer(
delta_text=delta_text,
delta_token_ids=delta_token_ids,
......@@ -646,6 +653,8 @@ class MistralToolParser(ToolParser):
tool_id = MistralToolCall.generate_random_id()
delta_function_name = delta_text.split("{")[0]
self.current_tool_name += delta_function_name
# HF tokenizers may include [ARGS] in the text
self.current_tool_name = self.current_tool_name.replace("[ARGS]", "")
delta_text = delta_text[len(delta_function_name) :]
self.streaming_state = StreamingState.PARSING_ARGUMENTS
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment