mistral.py 22.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Sequence
4
from functools import cached_property
5
from pathlib import Path
6
from typing import TYPE_CHECKING, Any, cast, overload
7

8
9
from mistral_common.guidance.grammar_factory import GrammarFactory
from mistral_common.guidance.tokenizer import from_mistral_tokenizer
10
11
12
from mistral_common.protocol.instruct.request import (
    ChatCompletionRequest as MistralChatCompletionRequest,
)
Julien Denize's avatar
Julien Denize committed
13
14
15
from mistral_common.protocol.instruct.request import (
    ReasoningEffort,
)
16
17
18
19
20
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import (
    SpecialTokenPolicy,
    SpecialTokens,
21
22
23
24
25
26
27
28
    Tokenizer,
)
from mistral_common.tokens.tokenizers.instruct import (
    InstructTokenizerBase,
    InstructTokenizerV13,
)
from mistral_common.tokens.tokenizers.mistral import (
    MistralTokenizer as MistralCommonTokenizer,
29
30
31
32
33
)
from mistral_common.tokens.tokenizers.sentencepiece import (
    SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
34
from pydantic import ValidationError
35

36
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
37
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
38
from vllm.logger import init_logger
39
from vllm.tokenizers.protocol import TokenizerLike
40

41
42
43
44
45
46
47
48
try:
    # Transformers v5
    from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
    # Transformers v4
    from transformers.tokenization_mistral_common import (
        MistralCommonTokenizer as MistralCommonBackend,
    )
49

50
if TYPE_CHECKING:
51
    import llguidance
52
    from transformers import BatchEncoding
53

54
55
# Module-level logger: used below for tool-call-ID truncation warnings,
# unknown Tekken token warnings, and `warning_once` notices about tool fields
# that are stripped because mistral-common does not support them.
logger = init_logger(__name__)

56

57
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
    """Materialize lazily-validated ``tool_calls`` on assistant messages.

    Pydantic may store fields declared as iterables as a lazy pydantic-core
    ``ValidatorIterator`` instead of a concrete list. This affects
    ``tool_calls`` in ``ChatCompletionAssistantMessageParam``. On the Mistral
    path no chat template consumes that iterator, so the tool calls would
    never be deserialized. Draining the iterator is the workaround recommended
    upstream.

    See:
      - https://github.com/vllm-project/vllm/pull/9951 (credits: @gcalmettes)
      - https://github.com/pydantic/pydantic/issues/9467
      - https://github.com/pydantic/pydantic/issues/9541
      - https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291

    Args:
        request: The OpenAI-style chat request. Its ``messages`` are dicts and
            every assistant message gets a concrete ``tool_calls`` list
            (empty when absent or ``None``), written in place.

    Raises:
        ValueError: If consuming an assistant message's ``tool_calls`` raises
            a pydantic ``ValidationError``.
    """
    # TODO: remove once the minimum supported pydantic ships the upstream fix
    # (planned for v2.11).
    for message in request.messages:
        if message.get("role") != "assistant":
            continue

        tool_calls_validator = message.get("tool_calls", None)
        if tool_calls_validator is None:
            validated_tool_calls = []
        else:
            try:
                # Consuming the iterator is what triggers deserialization.
                validated_tool_calls = list(tool_calls_validator)
            except ValidationError as e:
                raise ValueError(
                    "Validating messages' `tool_calls` raised an error. "
                    "Please ensure `tool_calls` are iterable of tool calls."
                ) from e

        # `message` is the same dict object stored in `request.messages`.
        message["tool_calls"] = validated_tool_calls


96
def truncate_tool_call_ids(request: "ChatCompletionRequest"):
    """Truncate tool call IDs for Mistral's ID requirements.

    IDs longer than 9 characters are replaced by their last 9 characters.
    This applies both to ``id`` entries in assistant ``tool_calls`` and to the
    ``tool_call_id`` of tool-result messages (roles ``"tool"`` and
    ``"tool_results"``), so a call and its result are truncated identically.
    Messages are modified in place; a warning is logged for each truncation.

    Args:
        request: The OpenAI-style chat request whose ``messages`` are dicts.
    """
    max_id_len = 9

    def _truncate(tool_call_id: str, log_template: str) -> str:
        # Keep the tail of the ID; short IDs are returned unchanged.
        if len(tool_call_id) <= max_id_len:
            return tool_call_id
        truncated = tool_call_id[-max_id_len:]
        logger.warning(log_template, tool_call_id, truncated)
        return truncated

    for message in request.messages:
        role = message.get("role")
        if role == "assistant":
            # `tool_calls` may be missing or explicitly None; both mean "no
            # calls". Iterating None would otherwise raise TypeError.
            tool_calls = message.get("tool_calls") or []
            for tool_call in tool_calls:
                tool_call["id"] = _truncate(
                    tool_call["id"], "Truncating tool call ID: %s to %s"
                )
            message["tool_calls"] = tool_calls
        elif role in {"tool_results", "tool"} and "tool_call_id" in message:
            message["tool_call_id"] = _truncate(
                message["tool_call_id"], "Truncating tool_call_id: %s to %s"
            )


126
127
def _prepare_apply_chat_template_tools_and_messages(
    messages: list["ChatCompletionMessageParam"],
128
    tools: list[dict[str, Any]] | None = None,
129
130
    continue_final_message: bool = False,
    add_generation_prompt: bool = False,
131
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
132
    if add_generation_prompt and continue_final_message:
133
        raise ValueError(
134
135
            "Cannot set both `add_generation_prompt` and "
            "`continue_final_message` to True."
136
        )
137

138
139
140
141
142
143
144
145
146
147
    last_message = cast(dict[str, Any], messages[-1])
    # add_generation_prompt is directly handled by the tokenizer but we
    # check if the user is trying to use it with a final assistant message
    # which is probably not what they want.
    # If add_generation_prompt is False, we don't need to check anything.
    if add_generation_prompt and last_message["role"] == "assistant":
        raise ValueError(
            "Cannot set `add_generation_prompt` to True when "
            "the last message is from the assistant. Consider "
            "using `continue_final_message` instead."
148
        )
149
150
151
152
    if continue_final_message and last_message["role"] != "assistant":
        raise ValueError(
            "Cannot set `continue_final_message` to True when "
            "the last message is not from the assistant."
153
        )
154

155
156
157
158
    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
159
160
        # Remove reasoning as unsupported by Mistral
        _ = message.pop("reasoning", None)  # type: ignore
161

162
    # The Mistral client, in comparison to the OpenAI client, requires the
163
164
    # "parameters" dict and the "description" string to be present
    # even if they are empty.
165
166
    if tools:
        for function in [
167
            tool["function"] for tool in tools if tool["type"] == "function"
168
        ]:
169
170
            if function.get("parameters") is None:
                function["parameters"] = {}
171
172
            if function.get("description") is None:
                function["description"] = ""
173

174
175
176
177
178
179
180
181
182
183
184
185
        # We filter not supported arguments to avoid throwing an error.
        # TODO(juliendenize): remove this once OpenAI API is better supported by
        # `mistral-common`.
        tools_fields = set(Tool.model_fields.keys())
        function_fields = set(Function.model_fields.keys())
        for tool in tools:
            tool_keys = list(tool.keys())
            for tool_key in tool_keys:
                if tool_key not in tools_fields:
                    tool.pop(tool_key)
                    logger.warning_once(
                        f"'{tool_key}' is not supported by mistral-common for tools. "
Jiayi Yan's avatar
Jiayi Yan committed
186
                        "It has been popped from the tool definition."
187
188
189
190
191
192
193
194
                    )
                if tool["type"] == "function":
                    function_keys = list(tool["function"].keys())
                    for function_key in function_keys:
                        if function_key not in function_fields:
                            tool["function"].pop(function_key)
                            logger.warning_once(
                                f"'{function_key}' is not supported by mistral-common "
Jiayi Yan's avatar
Jiayi Yan committed
195
                                "for function tools. It has been popped from the "
196
197
198
199
200
                                "function definition."
                            )
                else:
                    raise ValueError("mistral-common only supports function tools.")

201
    return messages, tools
202
203


204
205
206
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject request options that Mistral tokenizers cannot honor.

    Raises:
        ValueError: If a custom ``chat_template`` or ``chat_template_kwargs`` is
            supplied, or if a truthy ``reasoning_effort`` is not a member of
            mistral-common's ``ReasoningEffort``.
    """
    wants_custom_template = not (
        request.chat_template is None and request.chat_template_kwargs is None
    )
    if wants_custom_template:
        raise ValueError("chat_template is not supported for Mistral tokenizers.")

    effort = request.reasoning_effort
    # An unset/empty effort is always accepted.
    if not effort or effort in list(ReasoningEffort):
        return

    supported_values = [member.value for member in ReasoningEffort]
    raise ValueError(
        f"reasoning_effort={effort} is not supported by "
        "Mistral models. Supported values are: "
        f"{supported_values}."
    )

217

218
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
    """Map a Tekken token (text or raw bytes) back to its token id.

    Lookup order:
      1. the byte-level regular vocabulary, whose ids are offset by the number
         of special tokens;
      2. the special-token reverse vocabulary, keyed by the token's string;
      3. otherwise a warning is logged and ``tokenizer.unk_id`` is returned.
    """
    assert isinstance(tokenizer, Tekkenizer), type(tokenizer)

    t_bytes = t if isinstance(t, bytes) else t.encode("utf-8")
    try:
        regular_id = tokenizer._tekken_token2id_nospecial[t_bytes]
    except KeyError:
        pass
    else:
        # Regular token ids come after the block of special token ids.
        return tokenizer.num_special_tokens + regular_id

    # Byte pieces can hold an incomplete UTF-8 sequence. A strict decode would
    # raise UnicodeDecodeError and skip the <unk> fallback below. Decode
    # leniently instead: special-token strings presumably never contain U+FFFD,
    # so a replaced piece simply falls through to <unk>.
    t_str = t_bytes.decode("utf-8", errors="replace")
    if t_str in tokenizer._special_tokens_reverse_vocab:
        return tokenizer._special_tokens_reverse_vocab[t_str]
    logger.warning(
        "Failed to convert token %s to id, replacing with <unk>", t_bytes
    )
    return tokenizer.unk_id

234

235
class MistralTokenizer(TokenizerLike):
    """vLLM ``TokenizerLike`` adapter for Mistral models backed by mistral-common.

    Layers exposed as attributes:
      - ``transformers_tokenizer``: the transformers ``MistralCommonBackend``;
      - ``mistral``: the mistral-common ``MistralTokenizer`` it wraps;
      - ``instruct``: the instruct tokenizer of ``mistral``;
      - ``tokenizer``: the raw Tekken or SentencePiece tokenizer.
    """

    IS_MISTRAL_TOKENIZER = True  # used by vllm.utils.mistral

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "MistralTokenizer":
        """Load a tokenizer from a local path or a Hub repository.

        The backend is always created with ``mode=ValidationMode.test``, which
        ``__init__`` requires. ``trust_remote_code`` is accepted for interface
        compatibility but is not used. ``download_dir`` is forwarded as
        ``cache_dir``, and ``revision`` defaults to ``"main"``. Extra ``args``
        and ``kwargs`` are passed through to the transformers backend.
        """
        tokenizer = MistralCommonBackend.from_pretrained(
            path_or_repo_id,
            *args,
            mode=ValidationMode.test,
            cache_dir=download_dir,
            revision="main" if revision is None else revision,
            **kwargs,
        )

        return cls(tokenizer)

    def __init__(self, tokenizer: MistralCommonBackend) -> None:
        """Wrap an already-loaded transformers ``MistralCommonBackend``.

        Raises:
            ValueError: If the wrapped mistral-common tokenizer is not in
                ``ValidationMode.test``.
            TypeError: If the raw tokenizer is neither Tekken nor SentencePiece.
        """
        super().__init__()

        self.transformers_tokenizer: MistralCommonBackend = tokenizer
        self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
        self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
        self.tokenizer: Tokenizer = self.instruct.tokenizer

        mode = self.mistral._chat_completion_request_validator._mode
        if mode != ValidationMode.test:
            raise ValueError(
                "Mistral tokenizer must be in test mode. Make sure to "
                "set `mode='ValidationMode.test'` when creating the "
                "Mistral tokenizer."
            )

        # Keep only the numeric part of the tokenizer version (presumably of
        # the form "v<N>", e.g. "v13").
        _mistral_version_str = str(self.tokenizer.version.value)
        self.version: int = int(_mistral_version_str.split("v")[-1])

        self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
        self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
        if not (self.is_tekken or self.is_spm):
            raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")

        # token string -> id. Several ids may map to the same string, and the
        # lowest id is kept: ids are visited in increasing order and
        # `setdefault` keeps the first occurrence. As a result, the dict is
        # also ordered by token id.
        pieces = self.convert_ids_to_tokens(
            list(range(self.vocab_size)), skip_special_tokens=False
        )
        self._vocab_dict: dict[str, int] = {}
        for token_id, piece in enumerate(pieces):
            self._vocab_dict.setdefault(piece, token_id)

        # Vocab sorted by token id.
        self._vocab = self.tokenizer.vocab()
        self._max_token_id = self.vocab_size - 1
        self._max_chars_per_token = max(len(tok) for tok in self._vocab)

        # Cache special tokens for faster access.
        self._special_token_ids = self._get_special_token_ids()
        self._special_token_ids_set = set(self._special_token_ids)
        self._special_tokens = self._get_special_tokens(self._special_token_ids)
        self._special_tokens_set = set(self._special_tokens)

    def _get_special_token_ids(self) -> list[int]:
        """Return, in increasing order, the ids the raw tokenizer marks special."""
        return [i for i in range(len(self._vocab)) if self.tokenizer.is_special(i)]

    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
        """Decode each special id individually, keeping it as literal text."""
        return [
            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
            for i in all_special_ids
        ]

    def num_special_tokens_to_add(self) -> int:
        """Count the ids ``encode`` produces for an empty string."""
        return len(self.encode(""))

    # the following attributes are set to fit vLLM's design and are used
    # by the structured output backends.
    @property
    def all_special_tokens(self) -> list[str]:
        """Special tokens as text, aligned with ``all_special_ids``."""
        return self._special_tokens

    @property
    def all_special_ids(self) -> list[int]:
        """Ids of all special tokens, in increasing order."""
        return self._special_token_ids

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def pad_token_id(self) -> int:
        return self.tokenizer.pad_id

    @property
    def is_fast(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the transformers backend."""
        return self.transformers_tokenizer.vocab_size

    @property
    def max_token_id(self) -> int:
        """``vocab_size - 1``, computed once at construction."""
        return self._max_token_id

    @property
    def max_chars_per_token(self) -> int:
        """Length of the longest entry in ``vocab``."""
        return self._max_chars_per_token

    @property
    def truncation_side(self) -> str:
        """Truncation side configured on the transformers backend."""
        return self.transformers_tokenizer.truncation_side

    def _is_special_token_id(self, token_id: int) -> bool:
        return token_id in self._special_token_ids_set

    def __hash__(self) -> int:
        # Hash by object identity.
        return hash(id(self))

    def __len__(self) -> int:
        return self.vocab_size

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> "BatchEncoding":
        """Tokenize a string or a batch of strings through the transformers backend.

        A trailing EOS id is removed from each sequence, together with the
        matching ``attention_mask`` entry when one is present.

        Raises:
            ValueError: If ``text_pair`` is given (not supported).
        """
        if text_pair is not None:
            raise ValueError(
                "`text_pair` is not supported by `MistralTokenizer.__call__`."
            )

        encoded = self.transformers_tokenizer(
            text=text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )

        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, revert to only call self.transformers_tokenizer(...).
        # Hack to fix wrongly added eos token, when fix will be supported the condition
        # below will be False even before the revert is done.
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")
        if isinstance(text, str):
            self._drop_trailing_eos(input_ids, attention_mask)
        else:
            # Batched input: every field is a list of per-sequence lists, so
            # the fix has to be applied to each sequence individually.
            for i, seq_ids in enumerate(input_ids):
                self._drop_trailing_eos(
                    seq_ids, attention_mask[i] if attention_mask else None
                )
        return encoded

    def _drop_trailing_eos(
        self, input_ids: list[int], attention_mask: list[int] | None
    ) -> None:
        """Pop a trailing EOS id (and its mask entry) from one sequence, in place."""
        if input_ids and input_ids[-1] == self.eos_token_id:
            input_ids.pop(-1)
            if attention_mask:
                attention_mask.pop(-1)

    @property
    def vocab(self) -> list[str]:
        return self._vocab

    def get_vocab(self) -> dict[str, int]:
        """Token string to lowest token id, ordered by id."""
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        """Encode ``text`` into token ids with the raw mistral-common tokenizer.

        BOS is prepended when ``add_special_tokens`` is True; EOS is never
        added. When ``max_length`` is set and ``truncation`` is not explicitly
        ``False``, the result is cut to ``max_length`` ids on the side given by
        ``truncation_side``. This matches what ``__call__`` gets from the
        transformers backend.
        """
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.encode(...).
        encoded = self.tokenizer.encode(text, bos=add_special_tokens, eos=False)

        if truncation is False or max_length is None:
            return encoded
        if self.truncation_side == "left":
            # Keep the last `max_length` ids; an explicit start index avoids
            # `encoded[-0:]` returning everything when max_length == 0.
            return encoded[max(len(encoded) - max_length, 0) :]
        return encoded[:max_length]

    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: list[dict[str, Any]] | None = None,
        **kwargs,
    ) -> list[int]:
        """Render a conversation through the transformers backend.

        Recognized ``kwargs`` are ``add_generation_prompt``,
        ``continue_final_message``, ``tokenize``, ``padding``, ``truncation``
        and ``max_length``. ``reasoning_effort`` is also recognized for
        tokenizer versions >= 15. Other keys are ignored. ``messages`` and
        ``tools`` are sanitized in place before rendering.

        Returns:
            Token ids by default (``tokenize=True``). For ``tokenize=False``,
            whatever the transformers backend returns.

        Raises:
            ValueError: On inconsistent ``add_generation_prompt`` /
                ``continue_final_message`` usage or unsupported tools.
        """
        add_generation_prompt = kwargs.pop("add_generation_prompt", False)
        continue_final_message = kwargs.get("continue_final_message", False)
        tokenize = kwargs.get("tokenize", True)
        padding = kwargs.get("padding", False)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        version_kwargs = {}
        # NOTE: This is for backward compatibility.
        # Transformers should be passed arguments it knows.
        if self.version >= 15:
            version_kwargs["reasoning_effort"] = kwargs.get("reasoning_effort")

        messages, tools = _prepare_apply_chat_template_tools_and_messages(
            messages, tools, continue_final_message, add_generation_prompt
        )

        # `add_generation_prompt` is only validated above and is not forwarded.
        return self.transformers_tokenizer.apply_chat_template(
            conversation=messages,
            tools=tools,
            continue_final_message=continue_final_message,
            tokenize=tokenize,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=None,
            return_dict=False,
            **version_kwargs,
        )

    def decode(
        self, ids: Sequence[int] | int, skip_special_tokens: bool = False
    ) -> str:
        """Decode ids (or a single id) to text via the transformers backend."""
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.decode(...).
        if isinstance(ids, int):
            ids = [ids]

        return self.transformers_tokenizer.decode(
            ids, skip_special_tokens=skip_special_tokens
        )

    def batch_decode(
        self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        """Decode a batch of id sequences via the transformers backend."""
        return self.transformers_tokenizer.batch_decode(
            ids, skip_special_tokens=skip_special_tokens
        )

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        return self.transformers_tokenizer.convert_tokens_to_ids(tokens)

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join tokens back into text.

        Tool-call and think markers are kept verbatim. For Tekken, all other
        special tokens are dropped; if any remaining token is raw ``bytes``,
        every token is mapped back to its id and decoded together. For
        SentencePiece, runs of other tokens are decoded (ignoring special
        tokens) between the kept markers.
        """
        to_decode_special_tokens = {
            SpecialTokens.tool_calls,
            SpecialTokens.begin_think,
            SpecialTokens.end_think,
        }
        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            tokens = [
                t
                for t in tokens
                if (t in to_decode_special_tokens or t not in self._special_tokens_set)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens]
                # We filtered unwanted special tokens before
                # so we can decode the rest.
                return self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
            return "".join(tokens)

        # make sure certain special tokens like Tool calls are
        # not decoded
        assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
            self.tokenizer
        )

        regular_tokens: list[str] = []
        decoded_list: list[str] = []
        for token in tokens:
            if token in to_decode_special_tokens:
                # Flush the pending run before emitting the marker itself.
                if regular_tokens:
                    decoded_list.append(
                        self.tokenizer.decode(
                            regular_tokens, SpecialTokenPolicy.IGNORE
                        )
                    )
                    regular_tokens = []
                decoded_list.append(token)
            else:
                regular_tokens.append(token)

        if regular_tokens:
            decoded_list.append(
                self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
            )
        return "".join(decoded_list)

    def convert_ids_to_tokens(
        self,
        ids: Sequence[int],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """Convert ids to token pieces.

        With ``skip_special_tokens``, special ids are dropped except the
        tool-call token and, for ``InstructTokenizerV13``, the think markers.
        For Tekken, if any resulting piece contains U+FFFD, the pieces are
        rebuilt: regular ids become raw ``bytes`` pieces and special ids become
        decoded text. The returned list may therefore contain ``bytes``.
        """
        if not skip_special_tokens:
            return [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        non_skip_special_tokens_ids = {
            self.tokenizer.get_special_token(SpecialTokens.tool_calls),
        }
        if isinstance(self.instruct, InstructTokenizerV13):
            if self.instruct.BEGIN_THINK:
                non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
            if self.instruct.END_THINK:
                non_skip_special_tokens_ids.add(self.instruct.END_THINK)

        ids_kept = [
            i
            for i in ids
            if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
        ]

        # We filtered unwanted special tokens so we can decode the rest.
        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]

        if self.is_tekken and any("�" in t for t in tokens):
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # if underlying tokenizer is sentencepiece, we just add "�".
            tokens = [
                self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
                if token_id not in self._special_token_ids_set
                else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
                for token_id in ids_kept
            ]

        return tokens

    @property
    def supports_grammar(self) -> bool:
        """Whether mistral-common's ``GrammarFactory`` supports this tokenizer."""
        return GrammarFactory.is_supported(self.mistral)

    @cached_property
    def grammar_factory(self) -> GrammarFactory:
        """Lazily built, cached ``GrammarFactory``.

        Raises:
            AttributeError: If ``supports_grammar`` is False.
        """
        if not self.supports_grammar:
            raise AttributeError(
                "This tokenizer does not support `grammar_factory`. "
                "This is only supported for tekken tokenizers with "
                "version >= 11."
            )
        # Cache grammar factory to avoid creating a llguidance tokenizer at every usage.
        return GrammarFactory(self.mistral)

    @cached_property
    def llg_tokenizer(self) -> "llguidance.LLTokenizer":
        """Lazily built, cached llguidance tokenizer (Tekken only).

        Raises:
            ValueError: If the raw tokenizer is not a Tekkenizer.
        """
        if not self.is_tekken:
            raise ValueError("`llg_tokenizer` is only supported for Tekkenizers.")
        return from_mistral_tokenizer(self.mistral)