mistral.py 21.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from pathlib import Path
4
from typing import TYPE_CHECKING, Any, cast
5

6
from vllm.logger import init_logger
7
8

from .protocol import TokenizerLike
9
from .registry import TokenizerRegistry
10

11
if TYPE_CHECKING:
12
13
14
15
    from mistral_common.protocol.instruct.request import (
        ChatCompletionRequest as MistralChatCompletionRequest,
    )
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
16
    from transformers import BatchEncoding
17

18
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
19
    from vllm.entrypoints.openai.protocol import ChatCompletionRequest
20

21
22
23
24
25
26
27
28
29
    try:
        # Transformers v5
        from transformers.tokenization_mistral_common import MistralCommonBackend
    except ImportError:
        # Transformers v4
        from transformers.tokenization_mistral_common import (
            MistralCommonTokenizer as MistralCommonBackend,
        )

30
31
# Module-level logger shared by the helpers and the tokenizer class below.
logger = init_logger(__name__)

32

33
def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
    """Materialize the ``tool_calls`` of every assistant message in place.

    After the call, each message in ``request.messages`` whose role is
    ``"assistant"`` holds a plain ``list`` under ``"tool_calls"`` (an empty
    list when the key was missing or ``None``). Other messages are left
    untouched. Nothing is returned.

    The function is idempotent: calling it again on an already processed
    request simply copies the existing lists.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
    # declared as iterables are replaced in the instances by a
    # pydantic-core ValidatorIterator instance. In particular, this
    # affects tool_calls defined in the ChatCompletionAssistantMessageParam
    # model:
    # see:
    #   - https://github.com/pydantic/pydantic/issues/9467
    # As a result, tool_calls from assistant messages are never
    # deserialized in the request object if the tool_calls iterator is
    # not consumed. This affects messages passed to the MistralTokenizer
    # since no chat template is applied and therefore the tool_calls
    # iterator is not directly consumed.
    # The issue is tracked on the Pydantic side, with resolution planned for
    # the v2.11 release. In the meantime, the official workaround is to
    # consume the iterator so the tool_calls are correctly deserialized
    # in the OpenAI ChatCompletionAssistantMessageParam object
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # Official Pydantic Issues:
    #   - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        # ``list(...)`` drains a one-shot iterator exactly like a manual
        # ``next()`` loop, but it also accepts an already materialized list,
        # whereas ``next()`` on a list raises TypeError.
        request.messages[i]["tool_calls"] = (
            [] if tool_calls is None else list(tool_calls)
        )


70
def truncate_tool_call_ids(request: "MistralChatCompletionRequest"):
    """Shorten tool call IDs in place to their last 9 characters.

    Assistant messages have the ``"id"`` of each entry in ``"tool_calls"``
    shortened (the ``"tool_calls"`` key is always written back, as an empty
    list when it was absent). Messages with role ``"tool"`` or
    ``"tool_results"`` have ``"tool_call_id"`` shortened when the key is
    present. Every truncation is logged as a warning.
    """
    for position, msg in enumerate(request.messages):
        role = msg.get("role")

        if role == "assistant":
            calls = msg.get("tool_calls", [])
            for call in calls:
                call_id = call["id"]
                if len(call_id) <= 9:
                    continue
                shortened = call_id[-9:]
                logger.warning(
                    "Truncating tool call ID: %s to %s",
                    call_id,
                    shortened,
                )
                call["id"] = shortened

            request.messages[position]["tool_calls"] = calls

        elif role in {"tool_results", "tool"} and "tool_call_id" in msg:
            call_id = msg["tool_call_id"]
            if len(call_id) > 9:
                shortened = call_id[-9:]
                logger.warning(
                    "Truncating tool_call_id: %s to %s",
                    call_id,
                    shortened,
                )
                call_id = shortened
            request.messages[position]["tool_call_id"] = call_id


100
101
def _prepare_apply_chat_template_tools_and_messages(
    messages: list["ChatCompletionMessageParam"],
102
    tools: list[dict[str, Any]] | None = None,
103
104
    continue_final_message: bool = False,
    add_generation_prompt: bool = False,
105
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
106
107
    from mistral_common.protocol.instruct.tool_calls import Function, Tool

108
    if add_generation_prompt and continue_final_message:
109
        raise ValueError(
110
111
            "Cannot set both `add_generation_prompt` and "
            "`continue_final_message` to True."
112
        )
113

114
115
116
117
118
119
120
121
122
123
    last_message = cast(dict[str, Any], messages[-1])
    # add_generation_prompt is directly handled by the tokenizer but we
    # check if the user is trying to use it with a final assistant message
    # which is probably not what they want.
    # If add_generation_prompt is False, we don't need to check anything.
    if add_generation_prompt and last_message["role"] == "assistant":
        raise ValueError(
            "Cannot set `add_generation_prompt` to True when "
            "the last message is from the assistant. Consider "
            "using `continue_final_message` instead."
124
        )
125
126
127
128
    if continue_final_message and last_message["role"] != "assistant":
        raise ValueError(
            "Cannot set `continue_final_message` to True when "
            "the last message is not from the assistant."
129
        )
130

131
132
133
134
    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
135
136
        # Remove reasoning as unsupported by Mistral
        _ = message.pop("reasoning", None)  # type: ignore
137

138
    # The Mistral client, in comparison to the OpenAI client, requires the
139
140
    # "parameters" dict and the "description" string to be present
    # even if they are empty.
141
142
    if tools:
        for function in [
143
            tool["function"] for tool in tools if tool["type"] == "function"
144
        ]:
145
146
            if function.get("parameters") is None:
                function["parameters"] = {}
147
148
            if function.get("description") is None:
                function["description"] = ""
149

150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
        # We filter not supported arguments to avoid throwing an error.
        # TODO(juliendenize): remove this once OpenAI API is better supported by
        # `mistral-common`.
        tools_fields = set(Tool.model_fields.keys())
        function_fields = set(Function.model_fields.keys())
        for tool in tools:
            tool_keys = list(tool.keys())
            for tool_key in tool_keys:
                if tool_key not in tools_fields:
                    tool.pop(tool_key)
                    logger.warning_once(
                        f"'{tool_key}' is not supported by mistral-common for tools. "
                        "It has been poped from the tool definition."
                    )
                if tool["type"] == "function":
                    function_keys = list(tool["function"].keys())
                    for function_key in function_keys:
                        if function_key not in function_fields:
                            tool["function"].pop(function_key)
                            logger.warning_once(
                                f"'{function_key}' is not supported by mistral-common "
                                "for function tools. It has been poped from the "
                                "function definition."
                            )
                else:
                    raise ValueError("mistral-common only supports function tools.")

177
    return messages, tools
178
179


180
181
182
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject requests that try to customize the chat template.

    Raises:
        ValueError: if ``request.chat_template`` or
            ``request.chat_template_kwargs`` is not ``None``.
    """
    if request.chat_template is None and request.chat_template_kwargs is None:
        return
    raise ValueError("chat_template is not supported for Mistral tokenizers.")
183
184


185
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
    """Map a Tekken token, given as text or raw bytes, to its token id.

    Lookup order:

    1. the tokenizer's non-special vocabulary (keyed by bytes), with the
       result offset by ``tokenizer.num_special_tokens``;
    2. the special-token reverse vocabulary (keyed by text);
    3. ``tokenizer.unk_id``, after logging a warning.
    """
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer

    assert isinstance(tokenizer, Tekkenizer), type(tokenizer)

    t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
    shift = tokenizer.num_special_tokens
    try:
        return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
    except KeyError:
        pass

    try:
        t_str = t_bytes.decode("utf-8")
    except UnicodeDecodeError:
        # Byte tokens may hold an incomplete UTF-8 sequence; such a token
        # cannot be looked up by text, so fall through to <unk> instead of
        # letting the decode error escape.
        t_str = None

    if t_str is not None and t_str in tokenizer._special_tokens_reverse_vocab:
        return tokenizer._special_tokens_reverse_vocab[t_str]

    logger.warning("Failed to convert token %s to id, replacing with <unk>", t_bytes)
    return tokenizer.unk_id

203

204
@TokenizerRegistry.register("mistral")
205
class MistralTokenizer(TokenizerLike):
206
207
208
209
210
211
212
213
214
215
216
217
    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "MistralTokenizer":
        """Load the Transformers mistral-common backend and wrap it.

        The backend is always created in ``ValidationMode.test``.
        ``download_dir`` is forwarded as ``cache_dir`` and a missing
        ``revision`` defaults to ``"main"``. ``trust_remote_code`` is accepted
        but not used by this method.
        """
        from mistral_common.protocol.instruct.validator import ValidationMode

        try:
            # Transformers v5
            from transformers.tokenization_mistral_common import MistralCommonBackend
        except ImportError:
            # Transformers v4
            from transformers.tokenization_mistral_common import (
                MistralCommonTokenizer as MistralCommonBackend,
            )

        resolved_revision = revision if revision is not None else "main"
        backend = MistralCommonBackend.from_pretrained(
            path_or_repo_id,
            *args,
            mode=ValidationMode.test,
            cache_dir=download_dir,
            revision=resolved_revision,
            **kwargs,
        )
        return cls(backend)

238
    def __init__(self, tokenizer: "MistralCommonBackend") -> None:
        """Wrap a Transformers mistral-common backend tokenizer.

        Unpacks the nested mistral-common objects, checks that the backend
        was created in ``ValidationMode.test``, detects the tokenizer family
        (Tekken or SentencePiece) and caches the vocabulary and special-token
        tables used by the other methods.

        Raises:
            ValueError: if the backend is not in test validation mode.
            TypeError: if the underlying tokenizer is neither a Tekkenizer
                nor a SentencePieceTokenizer.
        """
        super().__init__()

        from mistral_common.protocol.instruct.validator import ValidationMode
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        self.transformers_tokenizer = tokenizer
        self.mistral = tokenizer.tokenizer
        self.instruct = self.mistral.instruct_tokenizer
        self.tokenizer = self.instruct.tokenizer

        mode = self.mistral._chat_completion_request_validator._mode
        if mode != ValidationMode.test:
            raise ValueError(
                "Mistral tokenizer must be in test mode. Make sure to "
                "set `mode='ValidationMode.test'` when creating the "
                "Mistral tokenizer."
            )

        # The version value ends with "v<number>"; keep only the number.
        _mistral_version_str = str(self.tokenizer.version.value)
        self.version: int = int(_mistral_version_str.split("v")[-1])

        self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
        self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
        if not (self.is_tekken or self.is_spm):
            raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")

        # Map every token string to the lowest id that produces it. Walking
        # the ids in ascending order and keeping the first occurrence yields
        # a dict that is already ordered by token id, so no sort is needed.
        vocab_dict: dict[str, int] = {}
        for token_id in range(self.vocab_size):
            token = self.convert_ids_to_tokens([token_id], skip_special_tokens=False)[
                0
            ]
            vocab_dict.setdefault(token, token_id)
        self._vocab_dict = vocab_dict

        # Cache special tokens for faster access.
        self._special_token_ids = self._get_special_token_ids()
        self._special_token_ids_set = set(self._special_token_ids)
        self._special_tokens = self._get_special_tokens(self._special_token_ids)
        self._special_tokens_set = set(self._special_tokens)

        # Vocab sorted by token id.
        self._vocab = self.tokenizer._vocab
        self._max_token_id = self.vocab_size - 1
285

286
    def _get_special_token_ids(self) -> list[int]:
        """Return the sorted ids of the underlying tokenizer's special tokens.

        For Tekken the ids are the ``"rank"`` of each entry in
        ``_all_special_tokens``; for SentencePiece they are the tokenizer's
        ``_control_tokens``.

        Raises:
            ValueError: if the tokenizer is neither Tekken nor SentencePiece.
        """
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        backend = self.tokenizer

        if self.is_tekken:
            assert isinstance(backend, Tekkenizer), type(backend)
            return sorted({entry["rank"] for entry in backend._all_special_tokens})

        if self.is_spm:
            assert isinstance(backend, SentencePieceTokenizer), type(backend)
            return sorted(backend._control_tokens)

        raise ValueError(f"Unknown tokenizer type: {type(backend)}")
303

304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
        """Decode each special token id to its text, one id at a time.

        Decoding uses ``SpecialTokenPolicy.KEEP`` and the result has the same
        order as ``all_special_ids``.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

        keep_policy = SpecialTokenPolicy.KEEP
        special_tokens: list[str] = []
        for special_id in all_special_ids:
            special_tokens.append(
                self.tokenizer.decode([special_id], special_token_policy=keep_policy)
            )
        return special_tokens

    # the following attributes are set to fit vLLM's design and are used
    # by the structured output backends.
    @property
    def all_special_tokens(self) -> list[str]:
        """Return the special-token strings cached in ``__init__``."""
        return self._special_tokens

    @property
    def all_special_ids(self) -> list[int]:
        """Return the special-token ids cached in ``__init__``."""
        return self._special_token_ids

322
323
324
325
326
327
328
329
    @property
    def bos_token_id(self) -> int:
        """Return the underlying tokenizer's ``bos_id``."""
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        """Return the underlying tokenizer's ``eos_id``."""
        return self.tokenizer.eos_id

330
331
332
333
    @property
    def pad_token_id(self) -> int:
        """Return the underlying tokenizer's ``pad_id``."""
        return self.tokenizer.pad_id

334
335
336
337
338
339
    @property
    def is_fast(self) -> bool:
        """Always report ``True`` (hard-coded)."""
        return True

    @property
    def vocab_size(self) -> int:
        """Return the vocabulary size reported by the Transformers backend."""
        return self.transformers_tokenizer.vocab_size
341

342
343
344
345
    @property
    def max_token_id(self) -> int:
        """Return the highest token id, cached in ``__init__`` as ``vocab_size - 1``."""
        return self._max_token_id

346
347
    @property
    def truncation_side(self) -> str:
        """Return the ``truncation_side`` of the Transformers backend."""
        return self.transformers_tokenizer.truncation_side
349

350
    def _is_special_token_id(self, token_id: int) -> bool:
        """Return whether ``token_id`` is in the cached set of special ids."""
        return token_id in self._special_token_ids_set
352

353
354
355
    def __hash__(self) -> int:
        """Hash the instance by identity (``id(self)``) rather than by content."""
        return hash(id(self))

356
357
358
    def __len__(self) -> int:
        """Return the vocabulary size (same value as ``vocab_size``)."""
        return self.vocab_size

359
360
    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> "BatchEncoding":
        """Tokenize a prompt or a batch of prompts with the Transformers backend.

        A trailing EOS token is removed from every encoded sequence (see the
        TODO below), together with the matching ``attention_mask`` entry when
        a mask is present.

        Raises:
            ValueError: if ``text_pair`` is given; it is not supported.
        """
        if text_pair is not None:
            raise ValueError(
                "`text_pair` is not supported by `MistralTokenizer.__call__`."
            )

        encoded = self.transformers_tokenizer(
            text=text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )

        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, revert to only call self.transformers_tokenizer(...).
        # Hack to fix wrongly added eos token, when fix will be supported the condition
        # below will be False even before the revert is done.
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")

        if input_ids and isinstance(input_ids[0], list):
            # Batched input: ``input_ids`` is a list of sequences, so the EOS
            # check has to be applied to each sequence individually.
            masks = attention_mask if attention_mask else [None] * len(input_ids)
            sequences = list(zip(input_ids, masks))
        else:
            sequences = [(input_ids, attention_mask)]

        eos_token_id = self.eos_token_id
        for sequence_ids, sequence_mask in sequences:
            if sequence_ids and sequence_ids[-1] == eos_token_id:
                sequence_ids.pop(-1)
                if sequence_mask:
                    sequence_mask.pop(-1)
        return encoded
388
389
390
391

    @property
    def vocab(self) -> list[str]:
        """Return the vocabulary list taken from the underlying tokenizer in ``__init__``."""
        return self._vocab
392

393
    def get_vocab(self) -> dict[str, int]:
        """Return the token-to-id mapping built in ``__init__``."""
        return self._vocab_dict
395

396
    def get_added_vocab(self) -> dict[str, int]:
        """Return an empty mapping: Mistral tokenizers have no added vocabulary."""
        return {}
399

400
401
402
    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        """Encode ``text`` to token ids with the underlying mistral-common tokenizer.

        BOS is added when ``add_special_tokens`` is true; EOS is never added.
        The result is cut to ``max_length`` ids whenever ``max_length`` is
        given, unless ``truncation`` is explicitly ``False``.
        """
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.encode(...).
        token_ids = self.tokenizer.encode(text, bos=add_special_tokens, eos=False)

        # Note that ``truncation=None`` (the default) still truncates.
        if truncation is False or max_length is None:
            return token_ids
        return token_ids[:max_length]
415

416
417
418
    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: list[dict[str, Any]] | None = None,
        **kwargs,
    ) -> list[int]:
        """Tokenize a conversation through the Transformers backend.

        Recognized keyword arguments are ``add_generation_prompt``,
        ``continue_final_message``, ``padding``, ``truncation`` and
        ``max_length``; any others are ignored. ``messages`` and ``tools``
        are validated and normalized first, which may raise ``ValueError``.
        """
        generation_prompt = kwargs.pop("add_generation_prompt", False)
        continue_final = kwargs.get("continue_final_message", False)

        prepared_messages, prepared_tools = (
            _prepare_apply_chat_template_tools_and_messages(
                messages, tools, continue_final, generation_prompt
            )
        )

        return self.transformers_tokenizer.apply_chat_template(
            conversation=prepared_messages,
            tools=prepared_tools,
            continue_final_message=continue_final,
            tokenize=True,
            padding=kwargs.get("padding", False),
            truncation=kwargs.get("truncation", False),
            max_length=kwargs.get("max_length"),
            return_tensors=None,
            return_dict=False,
        )

444
    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        """Decode token ids to text with the Transformers backend.

        A single integer id is wrapped into a one-element list first.
        """
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.decode(...).
        token_ids = [ids] if isinstance(ids, int) else ids
        return self.transformers_tokenizer.decode(
            token_ids, skip_special_tokens=skip_special_tokens
        )
453

454
455
456
457
458
459
460
    def batch_decode(
        self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        """Decode a batch of token-id sequences with the Transformers backend.

        Delegates to the backend's ``batch_decode``, which returns one string
        per entry of ``ids``.
        """
        return self.transformers_tokenizer.batch_decode(
            ids, skip_special_tokens=skip_special_tokens
        )

461
    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join a sequence of tokens back into text.

        Special tokens are dropped, except the tool-calls token which is kept
        verbatim. For Tekken, tokens are concatenated directly unless one of
        them is a ``bytes`` object, in which case all of them are mapped back
        to ids and decoded together. For SentencePiece, runs of regular
        tokens are decoded between the kept special tokens.
        """
        from mistral_common.tokens.tokenizers.base import (
            SpecialTokenPolicy,
            SpecialTokens,
        )
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        to_decode_special_tokens = {SpecialTokens.tool_calls}

        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            kept_tokens = [
                t
                for t in tokens
                if (t in to_decode_special_tokens or t not in self._special_tokens_set)
            ]

            if not any(isinstance(t, bytes) for t in kept_tokens):
                return "".join(kept_tokens)

            # we need to encode and decode all tokens again
            token_ids = [_tekken_token_to_id(self.tokenizer, t) for t in kept_tokens]
            # We filtered unwanted special tokens before
            # so we can decode the rest.
            return self.tokenizer.decode(token_ids, SpecialTokenPolicy.KEEP)

        # make sure certain special tokens like Tool calls are
        # not decoded
        assert isinstance(self.tokenizer, SentencePieceTokenizer), type(self.tokenizer)

        pieces: list[str] = []
        pending: list[str] = []
        for token in tokens:
            if token not in to_decode_special_tokens:
                pending.append(token)
                continue

            # Flush the regular tokens collected so far, then emit the
            # special token untouched.
            if pending:
                pieces.append(self.tokenizer.decode(pending, SpecialTokenPolicy.IGNORE))
                pending = []
            pieces.append(token)

        if pending:
            pieces.append(self.tokenizer.decode(pending, SpecialTokenPolicy.IGNORE))

        return "".join(pieces)
519
520

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """Convert token ids to their token pieces.

        With ``skip_special_tokens=False`` every id is converted. Otherwise
        special ids are dropped first, except the tool-calls control token
        and, for ``InstructTokenizerV13``, the begin/end think tokens when
        they are set. For Tekken, if any resulting piece contains the Unicode
        replacement character, the kept ids are converted again to byte
        pieces (special ids are decoded instead).
        """
        from mistral_common.tokens.tokenizers.base import (
            SpecialTokenPolicy,
            SpecialTokens,
        )
        from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13

        if not skip_special_tokens:
            return [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        non_skip_special_tokens_ids = {
            self.tokenizer.get_control_token(SpecialTokens.tool_calls),
        }
        if isinstance(self.instruct, InstructTokenizerV13):
            for think_token_id in (self.instruct.BEGIN_THINK, self.instruct.END_THINK):
                if think_token_id:
                    non_skip_special_tokens_ids.add(think_token_id)

        kept_ids = [
            token_id
            for token_id in ids
            if token_id in non_skip_special_tokens_ids
            or not self._is_special_token_id(token_id)
        ]

        # We filtered unwanted special tokens so we can decode the rest.
        pieces = [self.tokenizer.id_to_piece(token_id) for token_id in kept_ids]

        if not (any("�" in piece for piece in pieces) and self.is_tekken):
            # if underlying tokenizer is sentencepiece, we just add "�".
            return pieces

        # if a decoded token contains the replacement character, then the
        # token has an incomplete UTF-8 character so we must use bytes
        # See: https://github.com/vllm-project/vllm/pull/8640
        #      https://github.com/vllm-project/vllm/pull/9625
        byte_pieces = []
        for token_id in kept_ids:
            if token_id in self._special_token_ids_set:
                byte_pieces.append(
                    self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
                )
            else:
                byte_pieces.append(
                    self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
                )
        return byte_pieces