Unverified Commit c6187f55 authored by Julien Denize's avatar Julien Denize Committed by GitHub
Browse files

Refactor MistralTokenizer (#26358)


Signed-off-by: default avatarJulien Denize <julien.denize@mistral.ai>
parent 8983e021
...@@ -145,7 +145,7 @@ Supported models: ...@@ -145,7 +145,7 @@ Supported models:
Known issues: Known issues:
1. Mistral 7B struggles to generate parallel tool calls correctly. 1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is 2. **For Transformers tokenization backend only**: Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided: is not met, the following additional chat templates are provided:
...@@ -154,7 +154,14 @@ Known issues: ...@@ -154,7 +154,14 @@ Known issues:
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt * <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
when tools are provided, that results in much better reliability when working with parallel tool calling. when tools are provided, that results in much better reliability when working with parallel tool calling.
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` Recommended flags:
1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend:
`--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral`
2. To use the default Transformers tokenization backend:
`--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
### Llama Models (`llama3_json`) ### Llama Models (`llama3_json`)
......
...@@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple): ...@@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple):
# Voxtral # Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData: def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
from mistral_common.audio import Audio from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import ( from mistral_common.protocol.instruct.chunk import (
AudioChunk, AudioChunk,
RawAudio, RawAudio,
TextChunk, TextChunk,
)
from mistral_common.protocol.instruct.messages import (
UserMessage, UserMessage,
) )
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
......
...@@ -32,7 +32,7 @@ pyzmq >= 25.0.0 ...@@ -32,7 +32,7 @@ pyzmq >= 25.0.0
msgspec msgspec
gguf >= 0.13.0 gguf >= 0.13.0
importlib_metadata; python_version < '3.10' importlib_metadata; python_version < '3.10'
mistral_common[image,audio] >= 1.8.2 mistral_common[image,audio] >= 1.8.5
opencv-python-headless >= 4.11.0 # required for video IO opencv-python-headless >= 4.11.0 # required for video IO
pyyaml pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
......
...@@ -23,7 +23,7 @@ jiwer # required for audio tests ...@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test timm # required for internvl test
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test datamodel_code_generator # required for minicpm3 test
......
...@@ -29,7 +29,7 @@ torchaudio==2.8.0 ...@@ -29,7 +29,7 @@ torchaudio==2.8.0
torchvision==0.23.0 torchvision==0.23.0
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test opencv-python-headless >= 4.11.0 # required for video test
......
...@@ -474,7 +474,7 @@ mbstrdecoder==1.1.3 ...@@ -474,7 +474,7 @@ mbstrdecoder==1.1.3
# typepy # typepy
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
mistral-common==1.8.2 mistral-common==1.8.5
# via -r requirements/test.in # via -r requirements/test.in
mlflow==2.22.0 mlflow==2.22.0
# via terratorch # via terratorch
...@@ -1012,8 +1012,6 @@ sentence-transformers==3.2.1 ...@@ -1012,8 +1012,6 @@ sentence-transformers==3.2.1
# via # via
# -r requirements/test.in # -r requirements/test.in
# mteb # mteb
sentencepiece==0.2.0
# via mistral-common
setuptools==77.0.3 setuptools==77.0.3
# via # via
# lightning-utilities # lightning-utilities
......
...@@ -6,8 +6,7 @@ from collections.abc import Mapping ...@@ -6,8 +6,7 @@ from collections.abc import Mapping
from typing import Literal, Optional from typing import Literal, Optional
import pytest import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
...@@ -2119,34 +2118,9 @@ def test_apply_mistral_chat_template_thinking_chunk(): ...@@ -2119,34 +2118,9 @@ def test_apply_mistral_chat_template_thinking_chunk():
}, },
{"role": "user", "content": "Thanks, what is 3+3?"}, {"role": "user", "content": "Thanks, what is 3+3?"},
] ]
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained( mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507" "mistralai/Magistral-Small-2509"
)
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
) )
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
tokens_ids = apply_mistral_chat_template( tokens_ids = apply_mistral_chat_template(
mistral_tokenizer, messages, chat_template=None, tools=None mistral_tokenizer, messages, chat_template=None, tools=None
......
...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional
import pytest import pytest
from mistral_common.multimodal import download_image from mistral_common.multimodal import download_image
from mistral_common.protocol.instruct.messages import ImageURLChunk from mistral_common.protocol.instruct.chunk import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
......
...@@ -6,12 +6,8 @@ import json ...@@ -6,12 +6,8 @@ import json
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from mistral_common.audio import Audio from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import ( from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
AudioChunk, from mistral_common.protocol.instruct.messages import UserMessage
RawAudio,
TextChunk,
UserMessage,
)
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.transformers_utils.tokenizer import MistralTokenizer
......
...@@ -6,7 +6,8 @@ from typing import Optional, Union ...@@ -6,7 +6,8 @@ from typing import Optional, Union
import numpy as np import numpy as np
import pytest import pytest
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
......
...@@ -9,7 +9,8 @@ from typing import Any, Union ...@@ -9,7 +9,8 @@ from typing import Any, Union
import numpy as np import numpy as np
import pytest import pytest
import torch.nn as nn import torch.nn as nn
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from tests.reasoning.utils import run_reasoning_extraction_mistral from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
...@@ -14,33 +12,9 @@ parser_name = "mistral" ...@@ -14,33 +12,9 @@ parser_name = "mistral"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def mistral_tokenizer(): def mistral_tokenizer():
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained( mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507" "mistralai/Magistral-Small-2509"
) )
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
return mistral_tokenizer return mistral_tokenizer
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import pytest import pytest
from mistral_common.protocol.instruct.messages import ( from mistral_common.exceptions import InvalidMessageStructureException
AssistantMessage, from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
ToolMessage,
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import (
Function,
FunctionCall,
Tool,
ToolCall,
)
from vllm.transformers_utils.tokenizers.mistral import ( from vllm.transformers_utils.tokenizers.mistral import (
make_mistral_chat_completion_request, MistralTokenizer,
_prepare_apply_chat_template_tools_and_messages,
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"openai_request,expected_mistral_request", "openai_request,expected_mistral_output",
[ [
( (
{ {
...@@ -41,19 +34,22 @@ from vllm.transformers_utils.tokenizers.mistral import ( ...@@ -41,19 +34,22 @@ from vllm.transformers_utils.tokenizers.mistral import (
} }
], ],
}, },
ChatCompletionRequest( (
messages=[ [
UserMessage(content="What is the current local date and time?") {
"role": "user",
"content": "What is the current local date and time?",
}
], ],
tools=[ [
Tool( {
type="function", "type": "function",
function=Function( "function": {
name="get_current_time", "description": "Fetch the current local date and time.",
description="Fetch the current local date and time.", "name": "get_current_time",
parameters={}, "parameters": {},
), },
) }
], ],
), ),
), ),
...@@ -71,39 +67,44 @@ from vllm.transformers_utils.tokenizers.mistral import ( ...@@ -71,39 +67,44 @@ from vllm.transformers_utils.tokenizers.mistral import (
"function": { "function": {
"description": "Fetch the current local date and time.", "description": "Fetch the current local date and time.",
"name": "get_current_time", "name": "get_current_time",
"parameters": None, "parameters": {},
}, },
} }
], ],
}, },
ChatCompletionRequest( (
messages=[ [
UserMessage(content="What is the current local date and time?") {
"role": "user",
"content": "What is the current local date and time?",
}
], ],
tools=[ [
Tool( {
type="function", "type": "function",
function=Function( "function": {
name="get_current_time", "description": "Fetch the current local date and time.",
description="Fetch the current local date and time.", "name": "get_current_time",
parameters={}, "parameters": {},
), },
) }
], ],
), ),
), ),
], ],
) )
def test_make_mistral_chat_completion_request(openai_request, expected_mistral_request): def test_prepare_apply_chat_template_tools_and_messages(
actual_request = make_mistral_chat_completion_request( openai_request, expected_mistral_output
):
actual_request = _prepare_apply_chat_template_tools_and_messages(
openai_request["messages"], openai_request["tools"] openai_request["messages"], openai_request["tools"]
) )
assert actual_request == expected_mistral_request assert actual_request == expected_mistral_output
# Tool use with list content and reasoning_content # Tool use with list content and reasoning_content
@pytest.mark.parametrize( @pytest.mark.parametrize(
"openai_request,expected_mistral_request", "openai_request,expected_mistral_output",
[ [
( (
{ {
...@@ -154,34 +155,517 @@ def test_make_mistral_chat_completion_request(openai_request, expected_mistral_r ...@@ -154,34 +155,517 @@ def test_make_mistral_chat_completion_request(openai_request, expected_mistral_r
} }
], ],
}, },
ChatCompletionRequest( (
messages=[ [
UserMessage(content="What's the weather in Paris?"), {
AssistantMessage( "role": "user",
content=None, "content": "What's the weather in Paris?",
tool_calls=[ },
ToolCall( {
id="call123", "role": "assistant",
function=FunctionCall( "content": None,
name="get_weather", "tool_calls": [
arguments='{"city": "Paris"}', {
"id": "call123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}
],
},
{
"role": "tool",
"content": [{"type": "text", "text": "Rainy"}],
"name": "get_weather",
"tool_call_id": "call123",
},
],
[
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Gets the current weather in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name",
}
},
"required": ["city"],
},
},
}
],
), ),
) )
], ],
)
def test_prepare_apply_chat_template_tools_and_messages_list_content(
openai_request, expected_mistral_output
):
actual_request = _prepare_apply_chat_template_tools_and_messages(
openai_request["messages"], openai_request["tools"]
)
assert actual_request == expected_mistral_output
def test_prepare_apply_chat_template_generation_prompt_and_continue():
messages = [{"role": "assistant", "content": "Hello"}]
tools: list[dict[str, Any]] = []
with pytest.raises(ValueError):
_prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=True
)
messages = [{"role": "user", "content": "Hello"}]
out_messages, _ = _prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=True
)
assert out_messages == [{"role": "user", "content": "Hello"}]
with pytest.raises(ValueError):
_prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=True, continue_final_message=True
)
messages = [{"role": "assistant", "content": "Hello"}]
out_messages, _ = _prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=False, continue_final_message=True
)
assert out_messages == [{"role": "assistant", "content": "Hello"}]
messages = [{"role": "user", "content": "Hello"}]
with pytest.raises(ValueError):
_prepare_apply_chat_template_tools_and_messages(
messages, tools, add_generation_prompt=False, continue_final_message=True
)
@pytest.fixture(scope="module")
def mistral_tokenizer(request) -> MistralTokenizer:
return MistralTokenizer.from_pretrained(request.param)
@pytest.mark.parametrize(
"mistral_tokenizer",
["mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Magistral-Small-2509"],
indirect=True,
)
class TestMistralTokenizer:
def test_all_special_tokens(self, mistral_tokenizer: MistralTokenizer):
attributes = [
mistral_tokenizer.all_special_tokens,
mistral_tokenizer.all_special_tokens_extended,
]
for attribute in attributes:
if mistral_tokenizer.is_tekken:
assert attribute == [
"<unk>",
"<s>",
"</s>",
"[INST]",
"[/INST]",
"[AVAILABLE_TOOLS]",
"[/AVAILABLE_TOOLS]",
"[TOOL_RESULTS]",
"[/TOOL_RESULTS]",
"[TOOL_CALLS]",
"[IMG]",
"<pad>",
"[IMG_BREAK]",
"[IMG_END]",
"[PREFIX]",
"[MIDDLE]",
"[SUFFIX]",
"[SYSTEM_PROMPT]",
"[/SYSTEM_PROMPT]",
"[TOOL_CONTENT]",
] + [f"<SPECIAL_{i}>" for i in range(20, 32)] + [
"[ARGS]",
"[CALL_ID]",
"[THINK]",
"[/THINK]",
] + [f"<SPECIAL_{i}>" for i in range(36, 1000)]
else:
assert attribute == [
"<s>",
"</s>",
"[INST]",
"[/INST]",
"[TOOL_CALLS]",
"[AVAILABLE_TOOLS]",
"[/AVAILABLE_TOOLS]",
"[TOOL_RESULTS]",
"[/TOOL_RESULTS]",
] + [f"[control_{i}]" for i in range(8, 769)]
def get_vocab(self, mistral_tokenizer: MistralTokenizer):
assert (
mistral_tokenizer.get_vocab()
== mistral_tokenizer.transformers_tokenizer.get_vocab()
)
def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer):
assert mistral_tokenizer.get_added_vocab() == {}
def test_encode_one(self, mistral_tokenizer: MistralTokenizer):
token_ids = (
[22177, 4304, 2662] if mistral_tokenizer.is_tekken else [23325, 2294, 1686]
)
assert mistral_tokenizer.encode_one("Hello world !") == token_ids
assert mistral_tokenizer.encode_one("Hello world !", max_length=1) == token_ids
assert (
mistral_tokenizer.encode_one("Hello world !", truncation=True, max_length=1)
== token_ids[:-2]
)
assert (
mistral_tokenizer.encode_one(
"Hello world !", truncation=False, max_length=1
)
== token_ids
)
def test_encode(self, mistral_tokenizer: MistralTokenizer):
token_ids = (
[1, 22177, 4304, 2662, 2]
if mistral_tokenizer.is_tekken
else [1, 23325, 2294, 1686, 2]
)
assert mistral_tokenizer.encode("Hello world !") == token_ids[:-1]
assert mistral_tokenizer.encode("Hello world !", max_length=3) == token_ids[:-2]
assert (
mistral_tokenizer.encode("Hello world !", truncation=True, max_length=3)
== token_ids[:-2]
)
assert (
mistral_tokenizer.encode("Hello world !", truncation=False, max_length=3)
== token_ids[:-1]
)
assert (
mistral_tokenizer.encode("Hello world !", add_special_tokens=True)
== token_ids
)
assert (
mistral_tokenizer.encode(
"Hello world !", add_special_tokens=True, max_length=3
)
== token_ids[:-2]
)
assert (
mistral_tokenizer.encode(
"Hello world !", add_special_tokens=True, truncation=False, max_length=3
)
== token_ids
)
assert (
mistral_tokenizer.encode("Hello world !", add_special_tokens=False)
== token_ids[1:-1]
)
@pytest.mark.parametrize(
"openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output",
[
(
{
"messages": [
{
"role": "user",
"content": "Hello world !",
}
],
},
True,
False,
([1, 3, 23325, 2294, 1686, 4], [1, 3, 22177, 4304, 2662, 4]),
("<s>[INST]▁Hello▁world▁![/INST]", ("<s>[INST]Hello world ![/INST]")),
),
(
{
"messages": [
{
"role": "system",
"content": "I am an AI",
},
{
"role": "user",
"content": "Hello world !",
},
],
},
True,
False,
(
[1, 3, 1083, 1605, 1164, 16875, 781, 781, 16998, 2294, 1686, 4],
[1, 17, 1073, 1855, 1420, 26554, 18, 3, 22177, 4304, 2662, 4],
),
(
"<s>[INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]",
(
"<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][INST]Hello world ![/INST]" # noqa: E501
),
),
),
(
{
"messages": [
{
"role": "system",
"content": "I am an AI",
},
{
"role": "user",
"content": "Hello world !",
},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Gets the current weather in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name",
}
},
"required": ["city"],
},
},
}
],
},
True,
False,
(
[
1,
6,
1501,
7567,
1891,
2032,
1113,
3396,
1316,
1113,
3396,
2032,
10598,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
7286,
2032,
1113,
2226,
29481,
1040,
2636,
8854,
1065,
1032,
3758,
9959,
1113,
12206,
2032,
10598,
1891,
2032,
1113,
3582,
1316,
1113,
11491,
2032,
10598,
19141,
2032,
10598,
1891,
2032,
1113,
2195,
1316,
1113,
7286,
2032,
1113,
1782,
3758,
1909,
29507,
11549,
1113,
11661,
2032,
8135,
19141,
3010,
1743,
10925,
7,
3,
1083,
1605,
1164,
16875,
781,
781,
16998,
2294,
1686,
4,
],
[
1,
17,
1073,
1855,
1420,
26554,
18,
5,
1091,
19227,
4994,
2811,
1429,
5165,
1897,
1429,
5165,
2811,
16753,
2391,
2811,
1429,
1689,
1095,
45629,
1897,
1429,
14653,
2811,
1429,
1071,
3083,
1278,
3519,
17253,
1294,
1261,
5970,
39249,
1429,
26204,
2811,
16753,
4994,
2811,
1429,
6371,
1897,
1429,
48649,
2811,
16753,
29363,
2811,
16753,
4994,
2811,
1429,
3607,
1897,
1429,
14653,
2811,
1429,
1784,
5970,
2564,
1034,
47579,
1429,
15760,
2811,
12161,
29363,
4964,
2821,
27028,
6,
3,
22177,
4304,
2662,
4,
],
),
(
'<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]',
(
'<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST]' # noqa: E501
),
), ),
ToolMessage(
content="Rainy",
tool_call_id="call123",
name="get_weather",
), ),
(
{
"messages": [
{
"role": "system",
"content": "I am an AI",
},
{
"role": "user",
"content": "Hello world !",
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "123456789",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "123456789",
"content": '{"temperature": 20, "unit": "celsius"}',
},
], ],
tools=[ "tools": [
Tool( {
type="function", "type": "function",
function=Function( "function": {
name="get_weather", "name": "get_weather",
description="Gets the current weather in a city.", "description": "Gets the current weather in a city.",
parameters={ "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"city": { "city": {
...@@ -191,17 +675,1535 @@ def test_make_mistral_chat_completion_request(openai_request, expected_mistral_r ...@@ -191,17 +675,1535 @@ def test_make_mistral_chat_completion_request(openai_request, expected_mistral_r
}, },
"required": ["city"], "required": ["city"],
}, },
},
}
],
},
True,
False,
(
[
1,
6,
1501,
7567,
1891,
2032,
1113,
3396,
1316,
1113,
3396,
2032,
10598,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
7286,
2032,
1113,
2226,
29481,
1040,
2636,
8854,
1065,
1032,
3758,
9959,
1113,
12206,
2032,
10598,
1891,
2032,
1113,
3582,
1316,
1113,
11491,
2032,
10598,
19141,
2032,
10598,
1891,
2032,
1113,
2195,
1316,
1113,
7286,
2032,
1113,
1782,
3758,
1909,
29507,
11549,
1113,
11661,
2032,
8135,
19141,
3010,
1743,
10925,
7,
3,
1083,
1605,
1164,
16875,
781,
781,
16998,
2294,
1686,
4,
5,
1501,
7567,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
17452,
2032,
10598,
19141,
2032,
1113,
4684,
1046,
8474,
1113,
1081,
2032,
1113,
29508,
29518,
29538,
29549,
29550,
29552,
29555,
29551,
29542,
29507,
10925,
2,
8,
10598,
4557,
2032,
10598,
29475,
17329,
2032,
29473,
29518,
29502,
29493,
1113,
6074,
2032,
1113,
29485,
1958,
3938,
8474,
1113,
3613,
29498,
1081,
2032,
1113,
29508,
29518,
29538,
29549,
29550,
29552,
29555,
29551,
29542,
18163,
9,
],
[
1,
17,
1073,
1855,
1420,
26554,
18,
5,
1091,
19227,
4994,
2811,
1429,
5165,
1897,
1429,
5165,
2811,
16753,
2391,
2811,
1429,
1689,
1095,
45629,
1897,
1429,
14653,
2811,
1429,
1071,
3083,
1278,
3519,
17253,
1294,
1261,
5970,
39249,
1429,
26204,
2811,
16753,
4994,
2811,
1429,
6371,
1897,
1429,
48649,
2811,
16753,
29363,
2811,
16753,
4994,
2811,
1429,
3607,
1897,
1429,
14653,
2811,
1429,
1784,
5970,
2564,
1034,
47579,
1429,
15760,
2811,
12161,
29363,
4964,
2821,
27028,
6,
3,
22177,
4304,
2662,
4,
9,
1689,
1095,
45629,
32,
19227,
29363,
2811,
1429,
42572,
46005,
2,
7,
19227,
113824,
2811,
1032,
1050,
1048,
1044,
1429,
8979,
2811,
1429,
1099,
79092,
46005,
8,
],
),
(
'<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST][TOOL_CALLS]▁[{"name":▁"get_weather",▁"arguments":▁{"city":▁"Paris"},▁"id":▁"123456789"}]</s>[TOOL_RESULTS]▁{"content":▁{"temperature":▁20,▁"unit":▁"celsius"},▁"call_id":▁"123456789"}[/TOOL_RESULTS]',
(
'<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST][TOOL_CALLS]get_weather[ARGS]{"city": "Paris"}</s>[TOOL_RESULTS]{"temperature": 20, "unit": "celsius"}[/TOOL_RESULTS]' # noqa: E501
),
),
),
(
{
"messages": [
{
"role": "user",
"content": "Hello world !",
},
{
"role": "assistant",
"content": "Hello ",
},
],
},
False,
True,
(
[1, 3, 23325, 2294, 1686, 4, 23325],
[1, 3, 22177, 4304, 2662, 4, 22177, 2],
),
(
"<s>[INST]▁Hello▁world▁![/INST]▁Hello",
("<s>[INST]Hello world ![/INST]Hello</s>"),
),
),
],
)
def test_apply_chat_template(
self,
mistral_tokenizer: MistralTokenizer,
openai_request: dict[str, Any],
add_generation_prompt: bool,
continue_final_message: bool,
expected_output: tuple[list[int], list[int]],
decoded_expected_output: tuple[str, str],
):
actual_output = mistral_tokenizer.apply_chat_template(
openai_request["messages"],
tools=openai_request.get("tools", []),
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
)
decoded_actual_output = mistral_tokenizer.tokenizer.decode(
actual_output, SpecialTokenPolicy.KEEP
)
assert actual_output == expected_output[mistral_tokenizer.is_tekken]
assert (
decoded_actual_output
== decoded_expected_output[mistral_tokenizer.is_tekken]
)
def test_apply_chat_template_error(self, mistral_tokenizer: MistralTokenizer):
messages = [{"role": "user", "content": "Hello world !"}]
with pytest.raises(ValueError):
mistral_tokenizer.apply_chat_template(
messages,
tools=[],
add_generation_prompt=True,
continue_final_message=True,
)
with pytest.raises(ValueError):
mistral_tokenizer.apply_chat_template(
messages,
tools=[],
add_generation_prompt=False,
continue_final_message=True,
)
messages = [
{"role": "user", "content": "Hello world !"},
{"role": "assistant", "content": "Hello "},
]
with pytest.raises(ValueError):
mistral_tokenizer.apply_chat_template(
messages,
tools=[],
add_generation_prompt=True,
continue_final_message=False,
)
messages = [
{"role": "user", "content": "Hello world !"},
{"role": "assistant", "content": "Hello "},
]
with pytest.raises(InvalidMessageStructureException):
mistral_tokenizer.apply_chat_template(
messages,
tools=[],
add_generation_prompt=False,
continue_final_message=False,
)
@pytest.mark.parametrize(
"skip_special_tokens,expected_tokens",
(
(
False,
(
"<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>",
"<s>[INST]Hello world ![/INST]Hello</s>",
),
),
(True, ("Hello world ! Hello", "Hello world !Hello")),
), ),
) )
def test_decode(
self,
mistral_tokenizer: MistralTokenizer,
skip_special_tokens: bool,
expected_tokens: tuple[str, str],
):
ids = (
[1, 3, 23325, 2294, 1686, 4, 23325, 2],
[1, 3, 22177, 4304, 2662, 4, 22177, 2],
)
assert (
mistral_tokenizer.decode(
ids[mistral_tokenizer.is_tekken],
skip_special_tokens=skip_special_tokens,
)
== expected_tokens[mistral_tokenizer.is_tekken]
)
def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
tokens = (
[
"<s>",
"[AVAILABLE_TOOLS]",
"▁[",
'{"',
"type",
'":',
'▁"',
"function",
'",',
'▁"',
"function",
'":',
'▁{"',
"name",
'":',
'▁"',
"get",
"_",
"we",
"ather",
'",',
'▁"',
"description",
'":',
'▁"',
"Get",
"s",
"▁the",
"▁current",
"▁weather",
"▁in",
"▁a",
"▁city",
'.",',
'▁"',
"parameters",
'":',
'▁{"',
"type",
'":',
'▁"',
"object",
'",',
'▁"',
"properties",
'":',
'▁{"',
"city",
'":',
'▁{"',
"type",
'":',
'▁"',
"string",
'",',
'▁"',
"description",
'":',
'▁"',
"The",
"▁city",
"▁name",
'"',
"}},",
'▁"',
"required",
'":',
'▁["',
"city",
'"]',
"}}",
"}]",
"[/AVAILABLE_TOOLS]",
"[INST]",
"▁I",
"▁am",
"▁an",
"▁AI",
"<0x0A>",
"<0x0A>",
"Hello",
"▁world",
"▁!",
"[/INST]",
"[TOOL_CALLS]",
"▁[",
'{"',
"name",
'":',
'▁"',
"get",
"_",
"we",
"ather",
'",',
'▁"',
"arguments",
'":',
'▁{"',
"city",
'":',
'▁"',
"Par",
"is",
'"},',
'▁"',
"id",
'":',
'▁"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"',
"}]",
"</s>",
"[TOOL_RESULTS]",
'▁{"',
"content",
'":',
'▁{"',
"t",
"emperature",
'":',
"▁",
"2",
"0",
",",
'▁"',
"unit",
'":',
'▁"',
"c",
"els",
"ius",
'"},',
'▁"',
"call",
"_",
"id",
'":',
'▁"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"}',
"[/TOOL_RESULTS]",
],
[
"<s>",
"[SYSTEM_PROMPT]",
"I",
" am",
" an",
" AI",
"[/SYSTEM_PROMPT]",
"[AVAILABLE_TOOLS]",
"[",
'{"',
"type",
'":',
' "',
"function",
'",',
' "',
"function",
'":',
' {"',
"name",
'":',
' "',
"get",
"_",
"weather",
'",',
' "',
"description",
'":',
' "',
"G",
"ets",
" the",
" current",
" weather",
" in",
" a",
" city",
'.",',
' "',
"parameters",
'":',
' {"',
"type",
'":',
' "',
"object",
'",',
' "',
"properties",
'":',
' {"',
"city",
'":',
' {"',
"type",
'":',
' "',
"string",
'",',
' "',
"description",
'":',
' "',
"The",
" city",
" name",
'"',
"}},",
' "',
"required",
'":',
' ["',
"city",
'"]',
"}}",
"}]",
"[/AVAILABLE_TOOLS]",
"[INST]",
"Hello",
" world",
" !",
"[/INST]",
"[TOOL_CALLS]",
"get",
"_",
"weather",
"[ARGS]",
'{"',
"city",
'":',
' "',
"Paris",
'"}',
"</s>",
"[TOOL_RESULTS]",
'{"',
"temperature",
'":',
" ",
"2",
"0",
",",
' "',
"unit",
'":',
' "',
"c",
"elsius",
'"}',
"[/TOOL_RESULTS]",
],
)
expected_strings = (
'[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}] I am an AI\n\nHello world ![TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, "id": "123456789"}] {"content": {"temperature": 20, "unit": "celsius"}, "call_id": "123456789"}', # noqa: E501
'I am an AI[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}]Hello world ![TOOL_CALLS]get_weather{"city": "Paris"}{"temperature": 20, "unit": "celsius"}', # noqa: E501
)
assert (
mistral_tokenizer.convert_tokens_to_string(
tokens[mistral_tokenizer.is_tekken]
)
== expected_strings[mistral_tokenizer.is_tekken]
)
@pytest.mark.parametrize(
"skip_special_tokens,tuple_expected_tokens",
(
(
True,
(
[
"▁[",
'{"',
"type",
'":',
'▁"',
"function",
'",',
'▁"',
"function",
'":',
'▁{"',
"name",
'":',
'▁"',
"get",
"_",
"we",
"ather",
'",',
'▁"',
"description",
'":',
'▁"',
"Get",
"s",
"▁the",
"▁current",
"▁weather",
"▁in",
"▁a",
"▁city",
'.",',
'▁"',
"parameters",
'":',
'▁{"',
"type",
'":',
'▁"',
"object",
'",',
'▁"',
"properties",
'":',
'▁{"',
"city",
'":',
'▁{"',
"type",
'":',
'▁"',
"string",
'",',
'▁"',
"description",
'":',
'▁"',
"The",
"▁city",
"▁name",
'"',
"}},",
'▁"',
"required",
'":',
'▁["',
"city",
'"]',
"}}",
"}]",
"▁I",
"▁am",
"▁an",
"▁AI",
"<0x0A>",
"<0x0A>",
"Hello",
"▁world",
"▁!",
"[TOOL_CALLS]",
"▁[",
'{"',
"name",
'":',
'▁"',
"get",
"_",
"we",
"ather",
'",',
'▁"',
"arguments",
'":',
'▁{"',
"city",
'":',
'▁"',
"Par",
"is",
'"},',
'▁"',
"id",
'":',
'▁"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"',
"}]",
'▁{"',
"content",
'":',
'▁{"',
"t",
"emperature",
'":',
"▁",
"2",
"0",
",",
'▁"',
"unit",
'":',
'▁"',
"c",
"els",
"ius",
'"},',
'▁"',
"call",
"_",
"id",
'":',
'▁"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"}',
], ],
[
"I",
" am",
" an",
" AI",
"[",
'{"',
"type",
'":',
' "',
"function",
'",',
' "',
"function",
'":',
' {"',
"name",
'":',
' "',
"get",
"_",
"weather",
'",',
' "',
"description",
'":',
' "',
"G",
"ets",
" the",
" current",
" weather",
" in",
" a",
" city",
'.",',
' "',
"parameters",
'":',
' {"',
"type",
'":',
' "',
"object",
'",',
' "',
"properties",
'":',
' {"',
"city",
'":',
' {"',
"type",
'":',
' "',
"string",
'",',
' "',
"description",
'":',
' "',
"The",
" city",
" name",
'"',
"}},",
' "',
"required",
'":',
' ["',
"city",
'"]',
"}}",
"}]",
"Hello",
" world",
" !",
"[TOOL_CALLS]",
"get",
"_",
"weather",
'{"',
"city",
'":',
' "',
"Paris",
'"}',
'{"',
"temperature",
'":',
" ",
"2",
"0",
",",
' "',
"unit",
'":',
' "',
"c",
"elsius",
'"}',
],
),
),
(
False,
(
[
"<s>",
"[AVAILABLE_TOOLS]",
"▁[",
'{"',
"type",
'":',
'▁"',
"function",
'",',
'▁"',
"function",
'":',
'▁{"',
"name",
'":',
'▁"',
"get",
"_",
"we",
"ather",
'",',
'▁"',
"description",
'":',
'▁"',
"Get",
"s",
"▁the",
"▁current",
"▁weather",
"▁in",
"▁a",
"▁city",
'.",',
'▁"',
"parameters",
'":',
'▁{"',
"type",
'":',
'▁"',
"object",
'",',
'▁"',
"properties",
'":',
'▁{"',
"city",
'":',
'▁{"',
"type",
'":',
'▁"',
"string",
'",',
'▁"',
"description",
'":',
'▁"',
"The",
"▁city",
"▁name",
'"',
"}},",
'▁"',
"required",
'":',
'▁["',
"city",
'"]',
"}}",
"}]",
"[/AVAILABLE_TOOLS]",
"[INST]",
"▁I",
"▁am",
"▁an",
"▁AI",
"<0x0A>",
"<0x0A>",
"Hello",
"▁world",
"▁!",
"[/INST]",
"[TOOL_CALLS]",
"▁[",
'{"',
"name",
'":',
'▁"',
"get",
"_",
"we",
"ather",
'",',
'▁"',
"arguments",
'":',
'▁{"',
"city",
'":',
'▁"',
"Par",
"is",
'"},',
'▁"',
"id",
'":',
'▁"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"',
"}]",
"</s>",
"[TOOL_RESULTS]",
'▁{"',
"content",
'":',
'▁{"',
"t",
"emperature",
'":',
"▁",
"2",
"0",
",",
'▁"',
"unit",
'":',
'▁"',
"c",
"els",
"ius",
'"},',
'▁"',
"call",
"_",
"id",
'":',
'▁"',
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'"}',
"[/TOOL_RESULTS]",
],
[
"<s>",
"[SYSTEM_PROMPT]",
"I",
" am",
" an",
" AI",
"[/SYSTEM_PROMPT]",
"[AVAILABLE_TOOLS]",
"[",
'{"',
"type",
'":',
' "',
"function",
'",',
' "',
"function",
'":',
' {"',
"name",
'":',
' "',
"get",
"_",
"weather",
'",',
' "',
"description",
'":',
' "',
"G",
"ets",
" the",
" current",
" weather",
" in",
" a",
" city",
'.",',
' "',
"parameters",
'":',
' {"',
"type",
'":',
' "',
"object",
'",',
' "',
"properties",
'":',
' {"',
"city",
'":',
' {"',
"type",
'":',
' "',
"string",
'",',
' "',
"description",
'":',
' "',
"The",
" city",
" name",
'"',
"}},",
' "',
"required",
'":',
' ["',
"city",
'"]',
"}}",
"}]",
"[/AVAILABLE_TOOLS]",
"[INST]",
"Hello",
" world",
" !",
"[/INST]",
"[TOOL_CALLS]",
"get",
"_",
"weather",
"[ARGS]",
'{"',
"city",
'":',
' "',
"Paris",
'"}',
"</s>",
"[TOOL_RESULTS]",
'{"',
"temperature",
'":',
" ",
"2",
"0",
",",
' "',
"unit",
'":',
' "',
"c",
"elsius",
'"}',
"[/TOOL_RESULTS]",
],
),
),
), ),
) )
def test_convert_ids_to_tokens(
self,
mistral_tokenizer: MistralTokenizer,
skip_special_tokens: bool,
tuple_expected_tokens: tuple[list[str], list[str]],
):
tuple_ids = (
[
1,
6,
1501,
7567,
1891,
2032,
1113,
3396,
1316,
1113,
3396,
2032,
10598,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
7286,
2032,
1113,
2226,
29481,
1040,
2636,
8854,
1065,
1032,
3758,
9959,
1113,
12206,
2032,
10598,
1891,
2032,
1113,
3582,
1316,
1113,
11491,
2032,
10598,
19141,
2032,
10598,
1891,
2032,
1113,
2195,
1316,
1113,
7286,
2032,
1113,
1782,
3758,
1909,
29507,
11549,
1113,
11661,
2032,
8135,
19141,
3010,
1743,
10925,
7,
3,
1083,
1605,
1164,
16875,
781,
781,
16998,
2294,
1686,
4,
5,
1501,
7567,
1629,
2032,
1113,
1295,
29498,
1537,
1991,
1316,
1113,
17452,
2032,
10598,
19141,
2032,
1113,
4684,
1046,
8474,
1113,
1081,
2032,
1113,
29508,
29518,
29538,
29549,
29550,
29552,
29555,
29551,
29542,
29507,
10925,
2,
8,
10598,
4557,
2032,
10598,
29475,
17329,
2032,
29473,
29518,
29502,
29493,
1113,
6074,
2032,
1113,
29485,
1958,
3938,
8474,
1113,
3613,
29498,
1081,
2032,
1113,
29508,
29518,
29538,
29549,
29550,
29552,
29555,
29551,
29542,
18163,
9,
], ],
) [
def test_make_mistral_chat_completion_request_list_content( 1,
openai_request, expected_mistral_request 17,
): 1073,
actual_request = make_mistral_chat_completion_request( 1855,
openai_request["messages"], openai_request["tools"] 1420,
26554,
18,
5,
1091,
19227,
4994,
2811,
1429,
5165,
1897,
1429,
5165,
2811,
16753,
2391,
2811,
1429,
1689,
1095,
45629,
1897,
1429,
14653,
2811,
1429,
1071,
3083,
1278,
3519,
17253,
1294,
1261,
5970,
39249,
1429,
26204,
2811,
16753,
4994,
2811,
1429,
6371,
1897,
1429,
48649,
2811,
16753,
29363,
2811,
16753,
4994,
2811,
1429,
3607,
1897,
1429,
14653,
2811,
1429,
1784,
5970,
2564,
1034,
47579,
1429,
15760,
2811,
12161,
29363,
4964,
2821,
27028,
6,
3,
22177,
4304,
2662,
4,
9,
1689,
1095,
45629,
32,
19227,
29363,
2811,
1429,
42572,
46005,
2,
7,
19227,
113824,
2811,
1032,
1050,
1048,
1044,
1429,
8979,
2811,
1429,
1099,
79092,
46005,
8,
],
)
ids = tuple_ids[mistral_tokenizer.is_tekken]
expected_tokens = tuple_expected_tokens[mistral_tokenizer.is_tekken]
actual_tokens = mistral_tokenizer.convert_ids_to_tokens(
ids, skip_special_tokens=skip_special_tokens
) )
assert actual_request == expected_mistral_request assert actual_tokens == expected_tokens
...@@ -403,20 +403,12 @@ def resolve_mistral_chat_template( ...@@ -403,20 +403,12 @@ def resolve_mistral_chat_template(
chat_template: Optional[str], chat_template: Optional[str],
**kwargs: Any, **kwargs: Any,
) -> Optional[str]: ) -> Optional[str]:
if chat_template is not None: if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
logger.warning_once( raise ValueError(
"'chat_template' cannot be overridden for mistral tokenizer." "'chat_template' or 'chat_template_kwargs' cannot be overridden "
) "for mistral tokenizer."
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
) )
return None return None
......
...@@ -10,7 +10,8 @@ from typing import Annotated, Literal, Optional, Union ...@@ -10,7 +10,8 @@ from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image from PIL import Image
......
...@@ -12,12 +12,8 @@ import regex as re ...@@ -12,12 +12,8 @@ import regex as re
import torch import torch
import torch.nn as nn import torch.nn as nn
from mistral_common.audio import mel_filter_bank from mistral_common.audio import mel_filter_bank
from mistral_common.protocol.instruct.messages import ( from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
AudioChunk, from mistral_common.protocol.instruct.messages import UserMessage
RawAudio,
TextChunk,
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
import huggingface_hub
import regex as re
from huggingface_hub import HfApi, hf_hub_download
from transformers.tokenization_utils_base import BatchEncoding
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_base import TokenizerBase from vllm.transformers_utils.tokenizer_base import TokenizerBase
from vllm.utils import is_list_of
if TYPE_CHECKING: if TYPE_CHECKING:
# make sure `mistral_common` is lazy imported, from mistral_common.protocol.instruct.request import (
# so that users who only use non-mistral models ChatCompletionRequest as MistralChatCompletionRequest,
# will not be bothered by the dependency. )
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.mistral import ( from transformers.tokenization_mistral_common import (
MistralTokenizer as PublicMistralTokenizer, MistralCommonTokenizer as TransformersMistralTokenizer,
) )
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
logger = init_logger(__name__) logger = init_logger(__name__)
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
# SEE: https://github.com/vllm-project/vllm/pull/9951 # SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes # Credits go to: @gcalmettes
# NOTE: There is currently a bug in pydantic where attributes # NOTE: There is currently a bug in pydantic where attributes
...@@ -65,7 +58,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): ...@@ -65,7 +58,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
request.messages[i]["tool_calls"] = validated_tool_calls request.messages[i]["tool_calls"] = validated_tool_calls
def truncate_tool_call_ids(request: "ChatCompletionRequest"): def truncate_tool_call_ids(request: "MistralChatCompletionRequest"):
"""Truncates tool call IDs for Mistral's ID requirements.""" """Truncates tool call IDs for Mistral's ID requirements."""
for i, message in enumerate(request.messages): for i, message in enumerate(request.messages):
if message.get("role") == "assistant": if message.get("role") == "assistant":
...@@ -95,85 +88,35 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"): ...@@ -95,85 +88,35 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
request.messages[i]["tool_call_id"] = tool_call_id request.messages[i]["tool_call_id"] = tool_call_id
def validate_request_params(request: "ChatCompletionRequest"): def _prepare_apply_chat_template_tools_and_messages(
if request.skip_special_tokens is not None and not request.skip_special_tokens: messages: list["ChatCompletionMessageParam"],
tools: Optional[list[dict[str, Any]]] = None,
continue_final_message: bool = False,
add_generation_prompt: bool = False,
) -> tuple[list["ChatCompletionMessageParam"], Optional[list[dict[str, Any]]]]:
if add_generation_prompt and continue_final_message:
raise ValueError( raise ValueError(
"skip_special_tokens=False is not supported for Mistral tokenizers." "Cannot set both `add_generation_prompt` and "
) "`continue_final_message` to True."
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
huggingface_hub.constants.REPO_ID_SEPARATOR.join(
["models", *repo_id.split("/")]
),
)
if revision is None:
revision_file = os.path.join(repo_cache, "refs", "main")
if os.path.isfile(revision_file):
with open(revision_file) as file:
revision = file.read()
if revision:
revision_dir = os.path.join(repo_cache, "snapshots", revision)
if os.path.isdir(revision_dir):
return os.listdir(revision_dir)
return []
def find_tokenizer_file(files: list[str]):
# Accept both versioned (tokenizer.model.v3) and unversioned
# (tokenizer.model) forms, plus tekken.json and tokenizer.mm.model
# variants. Previous pattern only matched the versioned variants.
file_pattern = re.compile(
r"^tokenizer\.model(\.v.*)?|tekken\.json|tokenizer\.mm\.model(\.v.*)?$"
) )
matched_files = [file for file in files if file_pattern.match(file)] last_message = cast(dict[str, Any], messages[-1])
if len(matched_files) > 1: # add_generation_prompt is directly handled by the tokenizer but we
logger.warning( # check if the user is trying to use it with a final assistant message
"Multiple files matched pattern `%s`: %s. Using %s.", # which is probably not what they want.
file_pattern.pattern, # If add_generation_prompt is False, we don't need to check anything.
matched_files, if add_generation_prompt and last_message["role"] == "assistant":
matched_files[0], raise ValueError(
"Cannot set `add_generation_prompt` to True when "
"the last message is from the assistant. Consider "
"using `continue_final_message` instead."
) )
elif len(matched_files) == 0: if continue_final_message and last_message["role"] != "assistant":
raise OSError( raise ValueError(
f"Found {len(matched_files)} files matching the " "Cannot set `continue_final_message` to True when "
f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " "the last message is not from the assistant."
f"tokenizer is present in {files}."
) )
return matched_files[0]
def _aggregate_content(content: list) -> list[dict[str, Any]]:
aggregated_content: list[dict[str, Any]] = []
for chunk in content:
if (
chunk.get("type") == "text"
and aggregated_content
and aggregated_content[-1].get("type") == "text"
):
aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
else:
aggregated_content.append(chunk)
if len(aggregated_content) == 1 and aggregated_content[0].get("type") == "text":
content = aggregated_content[0]["text"]
return content
def make_mistral_chat_completion_request(
messages: list["ChatCompletionMessageParam"],
tools: Optional[list[dict[str, Any]]] = None,
) -> "ChatCompletionRequest":
last_message = cast(dict[str, Any], messages[-1])
if last_message["role"] == "assistant":
last_message["prefix"] = True
# mistral-common requires AssistantMessage content to be string [1]. # mistral-common requires AssistantMessage content to be string [1].
# #
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
...@@ -181,13 +124,6 @@ def make_mistral_chat_completion_request( ...@@ -181,13 +124,6 @@ def make_mistral_chat_completion_request(
# Remove reasoning_content as unsupported by Mistral # Remove reasoning_content as unsupported by Mistral
_ = message.pop("reasoning_content", None) # type: ignore _ = message.pop("reasoning_content", None) # type: ignore
# Convert list text content to string
if message.get("role") in ("assistant", "tool"):
content: Any = message.get("content")
if isinstance(content, list):
content = _aggregate_content(content)
message["content"] = content
# The Mistral client, in comparison to the OpenAI client, requires the # The Mistral client, in comparison to the OpenAI client, requires the
# "parameters" dict and the "description" string to be present # "parameters" dict and the "description" string to be present
# even if they are empty. # even if they are empty.
...@@ -200,108 +136,113 @@ def make_mistral_chat_completion_request( ...@@ -200,108 +136,113 @@ def make_mistral_chat_completion_request(
if function.get("description") is None: if function.get("description") is None:
function["description"] = "" function["description"] = ""
from mistral_common.protocol.instruct.request import ChatCompletionRequest return messages, tools
return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var]
def validate_request_params(request: "ChatCompletionRequest"):
if request.chat_template is not None or request.chat_template_kwargs is not None:
raise ValueError("chat_template is not supported for Mistral tokenizers.")
class MistralTokenizer(TokenizerBase):
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
self.mistral = tokenizer
self.instruct = tokenizer.instruct_tokenizer
_mistral_version_str = self.instruct.tokenizer.version.value
self.version: int = int(_mistral_version_str.split("v")[-1])
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer def _tekken_token_to_id(tokenizer: "Tekkenizer", t: Union[str, bytes]) -> int:
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.tekken import Tekkenizer from mistral_common.tokens.tokenizers.tekken import Tekkenizer
self.is_tekken = isinstance(tokenizer_, Tekkenizer) assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
shift = tokenizer.num_special_tokens
try:
return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
except KeyError:
t_str = t_bytes.decode("utf-8")
if t_str in tokenizer._special_tokens_reverse_vocab:
return tokenizer._special_tokens_reverse_vocab[t_str]
logger.warning(
"Failed to convert token %s to id, replacing with <unk>", t_bytes
)
return tokenizer.unk_id
class MistralTokenizer(TokenizerBase):
def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
from mistral_common.tokens.tokenizers.sentencepiece import ( from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer, SentencePieceTokenizer,
) )
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) self.transformers_tokenizer = tokenizer
self._special_token_policy = ( self.mistral = tokenizer.tokenizer
SpecialTokenPolicy.IGNORE if self.is_tekken else None self.instruct = self.mistral.instruct_tokenizer
) self.tokenizer = self.instruct.tokenizer
_mistral_version_str = str(self.tokenizer.version.value)
self.version: int = int(_mistral_version_str.split("v")[-1])
self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
if not (self.is_tekken or self.is_spm): if not (self.is_tekken or self.is_spm):
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")
self._vocab = tokenizer_.vocab() # Reverse order to ensure that the lowest token id is kept.
# Convert to a dict[str, int] to match protocol, but this is a lossy self._vocab_dict = {
# conversion. There may be multiple token ids that decode to the same self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i
# string due to partial UTF-8 byte sequences being converted to � for i in range(self.vocab_size - 1, -1, -1)
self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)} }
self.tokenizer = tokenizer_ # Sort the dict for convenience
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
# Vocab sorted by token id.
self._vocab = self.tokenizer._vocab
self._max_token_id = self.vocab_size - 1 self._max_token_id = self.vocab_size - 1
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, path_or_repo_id: str, *, revision: Optional[str] = None cls, path_or_repo_id: str, *, revision: Optional[str] = None
) -> "MistralTokenizer": ) -> "MistralTokenizer":
if not Path(path_or_repo_id).exists(): from transformers.tokenization_mistral_common import (
assert len(path_or_repo_id.split("/")) == 2, ( MistralCommonTokenizer as TransformersMistralTokenizer,
"You have either provided a non-existent path: "
"{path_or_repo_id} or an invalid HF Hub repo id."
)
tokenizer_file = cls._download_mistral_tokenizer_from_hf(
path_or_repo_id, revision
) )
elif Path(path_or_repo_id).is_dir():
tokenizer_file_name = find_tokenizer_file(os.listdir(path_or_repo_id))
tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
else:
assert Path(path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
tokenizer_file = str(Path(path_or_repo_id))
from mistral_common.tokens.tokenizers.mistral import ( str_revision = "main" if revision is None else revision
MistralTokenizer as PublicMistralTokenizer, return cls(
TransformersMistralTokenizer.from_pretrained(
path_or_repo_id, revision=str_revision
) )
mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
return cls(mistral_tokenizer)
@staticmethod
def _download_mistral_tokenizer_from_hf(
tokenizer_name: str, revision: Optional[str]
) -> str:
try:
hf_api = HfApi()
files = hf_api.list_repo_files(repo_id=tokenizer_name, revision=revision)
except ConnectionError as exc:
files = list_local_repo_files(repo_id=tokenizer_name, revision=revision)
if len(files) == 0:
raise exc
filename = find_tokenizer_file(files)
tokenizer_file = hf_hub_download(
tokenizer_name, filename=filename, revision=revision
) )
return tokenizer_file
# the following attributes are set to fit vLLM's design and are used # the following attributes are set to fit vLLM's design and are used
# by the structured output backends. # by the structured output backends.
@property @property
def all_special_tokens_extended(self) -> list[str]: def all_special_tokens_extended(self) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens return self.all_special_tokens
# tekken defines its own extended special tokens list
if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
special_tokens = self.tokenizer.SPECIAL_TOKENS
else:
special_tokens = list(SpecialTokens)
return [s.value if isinstance(s, SpecialTokens) else s for s in special_tokens]
@property @property
def all_special_tokens(self) -> list[str]: def all_special_tokens(self) -> list[str]:
return self.all_special_tokens_extended from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
return [
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
for i in self.all_special_ids
]
@property @property
def all_special_ids(self) -> list[int]: def all_special_ids(self) -> list[int]:
return [self.all_special_tokens.index(t) for t in self.all_special_tokens] from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
elif self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
special_ids = self.tokenizer._control_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
return sorted(special_ids)
@property @property
def bos_token_id(self) -> int: def bos_token_id(self) -> int:
...@@ -317,7 +258,7 @@ class MistralTokenizer(TokenizerBase): ...@@ -317,7 +258,7 @@ class MistralTokenizer(TokenizerBase):
@property @property
def pad_token(self) -> str: def pad_token(self) -> str:
raise NotImplementedError() return self.transformers_tokenizer.pad_token
@property @property
def is_fast(self) -> bool: def is_fast(self) -> bool:
...@@ -325,7 +266,7 @@ class MistralTokenizer(TokenizerBase): ...@@ -325,7 +266,7 @@ class MistralTokenizer(TokenizerBase):
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
return len(self._vocab) return self.transformers_tokenizer.vocab_size
@property @property
def max_token_id(self) -> int: def max_token_id(self) -> int:
...@@ -335,6 +276,23 @@ class MistralTokenizer(TokenizerBase): ...@@ -335,6 +276,23 @@ class MistralTokenizer(TokenizerBase):
def truncation_side(self) -> str: def truncation_side(self) -> str:
raise NotImplementedError() raise NotImplementedError()
def _is_special_token_id(self, token_id: int) -> bool:
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
return token_id in self.tokenizer._control_tokens
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
return token_id < self.tokenizer.num_special_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
def __len__(self) -> int: def __len__(self) -> int:
return self.vocab_size return self.vocab_size
...@@ -346,25 +304,19 @@ class MistralTokenizer(TokenizerBase): ...@@ -346,25 +304,19 @@ class MistralTokenizer(TokenizerBase):
truncation: bool = False, truncation: bool = False,
max_length: Optional[int] = None, max_length: Optional[int] = None,
): ):
input_ids: Union[list[int], list[list[int]]] return self.transformers_tokenizer(
# For list[str], original prompt text text=text,
if is_list_of(text, str): text_pair=text_pair,
input_ids_: list[list[int]] = [] add_special_tokens=add_special_tokens,
for p in text: truncation=truncation,
each_input_ids = self.encode_one(p, truncation, max_length) max_length=max_length,
input_ids_.append(each_input_ids) )
input_ids = input_ids_
# For list[int], apply chat template output, already tokens. @property
elif is_list_of(text, int): def vocab(self) -> list[str]:
input_ids = text return self._vocab
# For str, single prompt text
else:
input_ids = self.encode_one(text, truncation, max_length)
return BatchEncoding({"input_ids": input_ids})
def get_vocab(self) -> dict[str, int]: def get_vocab(self) -> dict[str, int]:
# NB: the dictionary form of the vocabulary collapses token ids that map
# to the same string but have different bytes
return self._vocab_dict return self._vocab_dict
def get_added_vocab(self) -> dict[str, int]: def get_added_vocab(self) -> dict[str, int]:
...@@ -378,11 +330,9 @@ class MistralTokenizer(TokenizerBase): ...@@ -378,11 +330,9 @@ class MistralTokenizer(TokenizerBase):
max_length: Optional[int] = None, max_length: Optional[int] = None,
) -> list[int]: ) -> list[int]:
# Mistral Tokenizers should not add special tokens # Mistral Tokenizers should not add special tokens
input_ids = self.encode(text) return self.transformers_tokenizer.encode(
text, add_special_tokens=False, truncation=truncation, max_length=max_length
if truncation: )
input_ids = input_ids[:max_length]
return input_ids
def encode( def encode(
self, self,
...@@ -391,15 +341,20 @@ class MistralTokenizer(TokenizerBase): ...@@ -391,15 +341,20 @@ class MistralTokenizer(TokenizerBase):
max_length: Optional[int] = None, max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None, add_special_tokens: Optional[bool] = None,
) -> list[int]: ) -> list[int]:
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
if add_special_tokens is not None: if add_special_tokens is not None:
return self.tokenizer.encode( return self.transformers_tokenizer.encode(
text, bos=add_special_tokens, eos=add_special_tokens text,
truncation=truncation,
max_length=max_length,
add_special_tokens=add_special_tokens,
) )
else: else:
return self.tokenizer.encode(text, bos=True, eos=False) encoded = self.tokenizer.encode(text, bos=True, eos=False)
if truncation is not False and max_length is not None:
return encoded[:max_length]
else:
return encoded
def apply_chat_template( def apply_chat_template(
self, self,
...@@ -407,59 +362,79 @@ class MistralTokenizer(TokenizerBase): ...@@ -407,59 +362,79 @@ class MistralTokenizer(TokenizerBase):
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
**kwargs, **kwargs,
) -> list[int]: ) -> list[int]:
request = make_mistral_chat_completion_request(messages, tools) add_generation_prompt = kwargs.pop("add_generation_prompt", False)
encoded = self.mistral.encode_chat_completion(request) continue_final_message = kwargs.get("continue_final_message", False)
padding = kwargs.get("padding", False)
truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length")
messages, tools = _prepare_apply_chat_template_tools_and_messages(
messages, tools, continue_final_message, add_generation_prompt
)
# encode-decode to get clean prompt return self.transformers_tokenizer.apply_chat_template(
return encoded.tokens conversation=messages,
tools=tools,
continue_final_message=continue_final_message,
tokenize=True,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=None,
return_dict=False,
)
def decode(
self, ids: Union[list[int], int], skip_special_tokens: bool = True
) -> str:
return self.transformers_tokenizer.decode(
ids, skip_special_tokens=skip_special_tokens
)
def convert_tokens_to_string(self, tokens: list[str]) -> str: def convert_tokens_to_string(self, tokens: list[str]) -> str:
from mistral_common.tokens.tokenizers.base import SpecialTokens from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
to_decode_special_tokens = {SpecialTokens.tool_calls}
if self.is_tekken: if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
tokens = [ tokens = [
t t
for t in tokens for t in tokens
if ( if (t in to_decode_special_tokens or t not in self.all_special_tokens)
t is SpecialTokens.tool_calls
or t not in self.tokenizer._all_special_tokens
)
] ]
if any(isinstance(t, bytes) for t in tokens): if any(isinstance(t, bytes) for t in tokens):
# we need to encode and decode all tokens again # we need to encode and decode all tokens again
shift = self.tokenizer.num_special_tokens ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens]
# We filtered unwanted special tokens before
def _token_to_id(t: str): # so we can decode the rest.
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
try:
return (
shift + self.tokenizer._tekken_token2id_nospecial[t_bytes]
)
except KeyError:
logger.warning(
"Failed to convert token %s to id, replacing with <unk>",
t_bytes,
)
return self.tokenizer.unk_id
ids = [_token_to_id(t) for t in tokens]
decoded = self.tokenizer.decode(ids, self._special_token_policy)
else: else:
decoded = "".join(tokens) decoded = "".join(tokens)
else: else:
# make sure certain special tokens like Tool calls are # make sure certain special tokens like Tool calls are
# not decoded # not decoded
special_tokens = {SpecialTokens.tool_calls} assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
regular_tokens: list[str] = [] regular_tokens: list[str] = []
decoded_list = [] decoded_list: list[str] = []
decoded = ""
for token in tokens: for token in tokens:
if token in special_tokens: if token in to_decode_special_tokens:
if regular_tokens: if regular_tokens:
decoded_list.append( decoded_list.append(
self.tokenizer.decode( self.tokenizer.decode(
regular_tokens, self._special_token_policy regular_tokens, SpecialTokenPolicy.IGNORE
) )
) )
regular_tokens = [] regular_tokens = []
...@@ -469,66 +444,56 @@ class MistralTokenizer(TokenizerBase): ...@@ -469,66 +444,56 @@ class MistralTokenizer(TokenizerBase):
if regular_tokens: if regular_tokens:
decoded_list.append( decoded_list.append(
self.tokenizer.decode(regular_tokens, self._special_token_policy) self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
) )
decoded = "".join(decoded_list) decoded = "".join(decoded_list)
return decoded return decoded
def decode(
self, ids: Union[list[int], int], skip_special_tokens: bool = True
) -> str:
assert skip_special_tokens, (
"skip_special_tokens=False is not supported for Mistral tokenizers."
)
if isinstance(ids, int):
ids = [ids]
return self.tokenizer.decode(ids, self._special_token_policy)
def convert_ids_to_tokens( def convert_ids_to_tokens(
self, self,
ids: list[int], ids: list[int],
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
) -> list[str]: ) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens from mistral_common.tokens.tokenizers.base import (
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13 SpecialTokenPolicy,
SpecialTokens,
# TODO(Patrick) - potentially allow special tokens to not be skipped
assert skip_special_tokens, (
"skip_special_tokens=False is not supported for Mistral tokenizers."
) )
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
assert self.is_tekken or self.is_spm, type(self.tokenizer) if not skip_special_tokens:
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
if self.is_tekken: non_skip_special_tokens_ids = {
# skip special tokens except tool call and think tokens self.tokenizer.get_control_token(SpecialTokens.tool_calls),
non_skip_special_tokens = {
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
} }
if isinstance(self.instruct, InstructTokenizerV13): if isinstance(self.instruct, InstructTokenizerV13):
if self.instruct.BEGIN_THINK: if self.instruct.BEGIN_THINK:
non_skip_special_tokens.add(self.instruct.BEGIN_THINK) non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
if self.instruct.END_THINK: if self.instruct.END_THINK:
non_skip_special_tokens.add(self.instruct.END_THINK) non_skip_special_tokens_ids.add(self.instruct.END_THINK)
ids = [
ids_kept = [
i i
for i in ids for i in ids
if i > self.tokenizer.num_special_tokens or i in non_skip_special_tokens if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
] ]
tokens = [self.tokenizer.id_to_piece(id) for id in ids] # We filtered unwanted special tokens so we can decode the rest.
tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]
if any("�" in t for t in tokens) and self.is_tekken: if any("�" in t for t in tokens) and self.is_tekken:
# if a decoded token contains the replacement character, then the # if a decoded token contains the replacement character, then the
# token has an incomplete UTF-8 character so we must use bytes # token has an incomplete UTF-8 character so we must use bytes
# See: https://github.com/vllm-project/vllm/pull/8640 # See: https://github.com/vllm-project/vllm/pull/8640
# https://github.com/vllm-project/vllm/pull/9625 # https://github.com/vllm-project/vllm/pull/9625
# if underlying tokenizeir is sentencepiece, we just add "�" # if underlying tokenizer is sentencepiece, we just add "�".
# We filtered unwanted special tokens so we can decode the rest.
tokens = [ tokens = [
self.tokenizer.id_to_byte_piece(id, self._special_token_policy) self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
for id in ids if token_id not in self.all_special_ids
else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
for token_id in ids_kept
] ]
return tokens return tokens
...@@ -43,34 +43,13 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -43,34 +43,13 @@ class XgrammarBackend(StructuredOutputBackend):
if isinstance(self.tokenizer, MistralTokenizer): if isinstance(self.tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly. # NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
if self.tokenizer.is_tekken:
encoded_vocab = self.tokenizer._vocab
else:
encoded_vocab = [
token
for token, _ in sorted(
self.tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if (
hasattr(
self.tokenizer,
"eos_token_id",
)
and self.tokenizer.eos_token_id is not None
):
stop_token_ids = [self.tokenizer.eos_token_id] stop_token_ids = [self.tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError( # not self.tokenizer.vocab_size as self.tokenizer.vocab
f"Cannot get the vocabulary of the tokenizer " # collapses all decoded errors into a single token.
f"{type(self.tokenizer)}. The tokenizer should have a " self.vocab_size = len(self.tokenizer.vocab)
"get_vocab method."
) from e
tokenizer_info = xgr.TokenizerInfo( # type: ignore tokenizer_info = xgr.TokenizerInfo( # type: ignore
encoded_vocab=encoded_vocab, encoded_vocab=self.tokenizer.vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.RAW vocab_type=xgr.VocabType.RAW
if self.tokenizer.is_tekken if self.tokenizer.is_tekken
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment