Unverified Commit c6187f55 authored by Julien Denize's avatar Julien Denize Committed by GitHub
Browse files

Refactor MistralTokenizer (#26358)


Signed-off-by: default avatarJulien Denize <julien.denize@mistral.ai>
parent 8983e021
......@@ -145,7 +145,7 @@ Supported models:
Known issues:
1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
2. **For Transformers tokenization backend only**: Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided:
......@@ -154,7 +154,14 @@ Known issues:
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
when tools are provided, that results in much better reliability when working with parallel tool calling.
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
Recommended flags:
1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend:
`--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral`
2. To use the default Transformers tokenization backend:
`--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
### Llama Models (`llama3_json`)
......
......@@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple):
# Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import (
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
RawAudio,
TextChunk,
)
from mistral_common.protocol.instruct.messages import (
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
......
......@@ -32,7 +32,7 @@ pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
importlib_metadata; python_version < '3.10'
mistral_common[image,audio] >= 1.8.2
mistral_common[image,audio] >= 1.8.5
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
......
......@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
......
......@@ -29,7 +29,7 @@ torchaudio==2.8.0
torchvision==0.23.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test
......
......@@ -474,7 +474,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.8.2
mistral-common==1.8.5
# via -r requirements/test.in
mlflow==2.22.0
# via terratorch
......@@ -1012,8 +1012,6 @@ sentence-transformers==3.2.1
# via
# -r requirements/test.in
# mteb
sentencepiece==0.2.0
# via mistral-common
setuptools==77.0.3
# via
# lightning-utilities
......
......@@ -6,8 +6,7 @@ from collections.abc import Mapping
from typing import Literal, Optional
import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
......@@ -2119,34 +2118,9 @@ def test_apply_mistral_chat_template_thinking_chunk():
},
{"role": "user", "content": "Thanks, what is 3+3?"},
]
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507"
)
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
"mistralai/Magistral-Small-2509"
)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
tokens_ids = apply_mistral_chat_template(
mistral_tokenizer, messages, chat_template=None, tools=None
......
......@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional
import pytest
from mistral_common.multimodal import download_image
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.chunk import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
......
......@@ -6,12 +6,8 @@ import json
import pytest
import pytest_asyncio
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import (
AudioChunk,
RawAudio,
TextChunk,
UserMessage,
)
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from vllm.transformers_utils.tokenizer import MistralTokenizer
......
......@@ -6,7 +6,8 @@ from typing import Optional, Union
import numpy as np
import pytest
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
......
......@@ -9,7 +9,8 @@ from typing import Any, Union
import numpy as np
import pytest
import torch.nn as nn
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
......
......@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager
......@@ -14,33 +12,9 @@ parser_name = "mistral"
@pytest.fixture(scope="module")
def mistral_tokenizer():
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507"
"mistralai/Magistral-Small-2509"
)
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
return mistral_tokenizer
......
......@@ -403,20 +403,12 @@ def resolve_mistral_chat_template(
chat_template: Optional[str],
**kwargs: Any,
) -> Optional[str]:
if chat_template is not None:
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer."
)
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
raise ValueError(
"'chat_template' or 'chat_template_kwargs' cannot be overridden "
"for mistral tokenizer."
)
return None
......
......@@ -10,7 +10,8 @@ from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
......
......@@ -12,12 +12,8 @@ import regex as re
import torch
import torch.nn as nn
from mistral_common.audio import mel_filter_bank
from mistral_common.protocol.instruct.messages import (
AudioChunk,
RawAudio,
TextChunk,
UserMessage,
)
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
......
......@@ -43,34 +43,13 @@ class XgrammarBackend(StructuredOutputBackend):
if isinstance(self.tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
if self.tokenizer.is_tekken:
encoded_vocab = self.tokenizer._vocab
else:
encoded_vocab = [
token
for token, _ in sorted(
self.tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if (
hasattr(
self.tokenizer,
"eos_token_id",
)
and self.tokenizer.eos_token_id is not None
):
stop_token_ids = [self.tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(self.tokenizer)}. The tokenizer should have a "
"get_vocab method."
) from e
# not self.tokenizer.vocab_size as self.tokenizer.vocab
# collapses all decoded errors into a single token.
self.vocab_size = len(self.tokenizer.vocab)
tokenizer_info = xgr.TokenizerInfo( # type: ignore
encoded_vocab=encoded_vocab,
encoded_vocab=self.tokenizer.vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.RAW
if self.tokenizer.is_tekken
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment