Unverified Commit f20f9f06 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[mistral_common] Add v11 tokenizer (#19193)


Signed-off-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 9bc8bb07
......@@ -44,11 +44,17 @@ class MistralToolCall(ToolCall):
return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
return isinstance(model_tokenizer, MistralTokenizer) \
and model_tokenizer.version >= 11
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
examples/tool_chat_template_mistral.jinja template.
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
- the examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
......@@ -70,6 +76,12 @@ class MistralToolParser(ToolParser):
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
re.DOTALL)
else:
self.fn_name_regex = None
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
......@@ -109,10 +121,24 @@ class MistralToolParser(ToolParser):
tool_content = model_output.replace(self.bot_token, "").strip()
try:
# we first try to directly load the json as parsing very nested
# jsons is difficult
try:
if self.fn_name_regex:
matches = self.fn_name_regex.findall(tool_content)
function_call_arr = []
for match in matches:
fn_name = match[0]
args = match[1]
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append({
"name": fn_name,
"arguments": json.loads(args)
})
else:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
......
......@@ -187,6 +187,8 @@ class MistralTokenizer(TokenizerBase):
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
self.mistral = tokenizer
self.instruct = tokenizer.instruct_tokenizer
_mistral_version_str = self.instruct.tokenizer.version.value
self.version: int = int(_mistral_version_str.split("v")[-1])
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
from mistral_common.tokens.tokenizers.tekken import (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment