Unverified commit c6187f55, authored by Julien Denize, committed by GitHub
Browse files

Refactor MistralTokenizer (#26358)


Signed-off-by: Julien Denize <julien.denize@mistral.ai>
parent 8983e021
...@@ -145,7 +145,7 @@ Supported models: ...@@ -145,7 +145,7 @@ Supported models:
Known issues: Known issues:
1. Mistral 7B struggles to generate parallel tool calls correctly. 1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is 2. **For Transformers tokenization backend only**: Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided: is not met, the following additional chat templates are provided:
...@@ -154,7 +154,14 @@ Known issues: ...@@ -154,7 +154,14 @@ Known issues:
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt * <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
when tools are provided, that results in much better reliability when working with parallel tool calling. when tools are provided, that results in much better reliability when working with parallel tool calling.
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` Recommended flags:
1. To use [mistral-common](https://github.com/mistralai/mistral-common), the official Mistral tokenization backend:
`--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral`
2. To use the default Transformers tokenization backend:
`--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
### Llama Models (`llama3_json`) ### Llama Models (`llama3_json`)
......
...@@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple): ...@@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple):
# Voxtral # Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData: def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
from mistral_common.audio import Audio from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import ( from mistral_common.protocol.instruct.chunk import (
AudioChunk, AudioChunk,
RawAudio, RawAudio,
TextChunk, TextChunk,
)
from mistral_common.protocol.instruct.messages import (
UserMessage, UserMessage,
) )
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
......
...@@ -32,7 +32,7 @@ pyzmq >= 25.0.0 ...@@ -32,7 +32,7 @@ pyzmq >= 25.0.0
msgspec msgspec
gguf >= 0.13.0 gguf >= 0.13.0
importlib_metadata; python_version < '3.10' importlib_metadata; python_version < '3.10'
mistral_common[image,audio] >= 1.8.2 mistral_common[image,audio] >= 1.8.5
opencv-python-headless >= 4.11.0 # required for video IO opencv-python-headless >= 4.11.0 # required for video IO
pyyaml pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
......
...@@ -23,7 +23,7 @@ jiwer # required for audio tests ...@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test timm # required for internvl test
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test datamodel_code_generator # required for minicpm3 test
......
...@@ -29,7 +29,7 @@ torchaudio==2.8.0 ...@@ -29,7 +29,7 @@ torchaudio==2.8.0
torchvision==0.23.0 torchvision==0.23.0
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test opencv-python-headless >= 4.11.0 # required for video test
......
...@@ -474,7 +474,7 @@ mbstrdecoder==1.1.3 ...@@ -474,7 +474,7 @@ mbstrdecoder==1.1.3
# typepy # typepy
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
mistral-common==1.8.2 mistral-common==1.8.5
# via -r requirements/test.in # via -r requirements/test.in
mlflow==2.22.0 mlflow==2.22.0
# via terratorch # via terratorch
...@@ -1012,8 +1012,6 @@ sentence-transformers==3.2.1 ...@@ -1012,8 +1012,6 @@ sentence-transformers==3.2.1
# via # via
# -r requirements/test.in # -r requirements/test.in
# mteb # mteb
sentencepiece==0.2.0
# via mistral-common
setuptools==77.0.3 setuptools==77.0.3
# via # via
# lightning-utilities # lightning-utilities
......
...@@ -6,8 +6,7 @@ from collections.abc import Mapping ...@@ -6,8 +6,7 @@ from collections.abc import Mapping
from typing import Literal, Optional from typing import Literal, Optional
import pytest import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
...@@ -2119,34 +2118,9 @@ def test_apply_mistral_chat_template_thinking_chunk(): ...@@ -2119,34 +2118,9 @@ def test_apply_mistral_chat_template_thinking_chunk():
}, },
{"role": "user", "content": "Thanks, what is 3+3?"}, {"role": "user", "content": "Thanks, what is 3+3?"},
] ]
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained( mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507" "mistralai/Magistral-Small-2509"
)
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
) )
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
tokens_ids = apply_mistral_chat_template( tokens_ids = apply_mistral_chat_template(
mistral_tokenizer, messages, chat_template=None, tools=None mistral_tokenizer, messages, chat_template=None, tools=None
......
...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional
import pytest import pytest
from mistral_common.multimodal import download_image from mistral_common.multimodal import download_image
from mistral_common.protocol.instruct.messages import ImageURLChunk from mistral_common.protocol.instruct.chunk import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
......
...@@ -6,12 +6,8 @@ import json ...@@ -6,12 +6,8 @@ import json
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from mistral_common.audio import Audio from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import ( from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
AudioChunk, from mistral_common.protocol.instruct.messages import UserMessage
RawAudio,
TextChunk,
UserMessage,
)
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.transformers_utils.tokenizer import MistralTokenizer
......
...@@ -6,7 +6,8 @@ from typing import Optional, Union ...@@ -6,7 +6,8 @@ from typing import Optional, Union
import numpy as np import numpy as np
import pytest import pytest
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
......
...@@ -9,7 +9,8 @@ from typing import Any, Union ...@@ -9,7 +9,8 @@ from typing import Any, Union
import numpy as np import numpy as np
import pytest import pytest
import torch.nn as nn import torch.nn as nn
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from tests.reasoning.utils import run_reasoning_extraction_mistral from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
...@@ -14,33 +12,9 @@ parser_name = "mistral" ...@@ -14,33 +12,9 @@ parser_name = "mistral"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def mistral_tokenizer(): def mistral_tokenizer():
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained( mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507" "mistralai/Magistral-Small-2509"
) )
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
return mistral_tokenizer return mistral_tokenizer
......
...@@ -403,20 +403,12 @@ def resolve_mistral_chat_template( ...@@ -403,20 +403,12 @@ def resolve_mistral_chat_template(
chat_template: Optional[str], chat_template: Optional[str],
**kwargs: Any, **kwargs: Any,
) -> Optional[str]: ) -> Optional[str]:
if chat_template is not None: if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
logger.warning_once( raise ValueError(
"'chat_template' cannot be overridden for mistral tokenizer." "'chat_template' or 'chat_template_kwargs' cannot be overridden "
) "for mistral tokenizer."
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
) )
return None return None
......
...@@ -10,7 +10,8 @@ from typing import Annotated, Literal, Optional, Union ...@@ -10,7 +10,8 @@ from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image from PIL import Image
......
...@@ -12,12 +12,8 @@ import regex as re ...@@ -12,12 +12,8 @@ import regex as re
import torch import torch
import torch.nn as nn import torch.nn as nn
from mistral_common.audio import mel_filter_bank from mistral_common.audio import mel_filter_bank
from mistral_common.protocol.instruct.messages import ( from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
AudioChunk, from mistral_common.protocol.instruct.messages import UserMessage
RawAudio,
TextChunk,
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
......
...@@ -43,34 +43,13 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -43,34 +43,13 @@ class XgrammarBackend(StructuredOutputBackend):
if isinstance(self.tokenizer, MistralTokenizer): if isinstance(self.tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly. # NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try: stop_token_ids = [self.tokenizer.eos_token_id]
if self.tokenizer.is_tekken:
encoded_vocab = self.tokenizer._vocab # not self.tokenizer.vocab_size as self.tokenizer.vocab
else: # collapses all decoded errors into a single token.
encoded_vocab = [ self.vocab_size = len(self.tokenizer.vocab)
token
for token, _ in sorted(
self.tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if (
hasattr(
self.tokenizer,
"eos_token_id",
)
and self.tokenizer.eos_token_id is not None
):
stop_token_ids = [self.tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(self.tokenizer)}. The tokenizer should have a "
"get_vocab method."
) from e
tokenizer_info = xgr.TokenizerInfo( # type: ignore tokenizer_info = xgr.TokenizerInfo( # type: ignore
encoded_vocab=encoded_vocab, encoded_vocab=self.tokenizer.vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.RAW vocab_type=xgr.VocabType.RAW
if self.tokenizer.is_tekken if self.tokenizer.is_tekken
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment