mistral.py 22.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Sequence
4
from functools import cached_property
5
from pathlib import Path
6
from typing import TYPE_CHECKING, Any, cast, overload
7

8
9
from mistral_common.guidance.grammar_factory import GrammarFactory
from mistral_common.guidance.tokenizer import from_mistral_tokenizer
10
11
12
from mistral_common.protocol.instruct.request import (
    ChatCompletionRequest as MistralChatCompletionRequest,
)
Julien Denize's avatar
Julien Denize committed
13
14
15
from mistral_common.protocol.instruct.request import (
    ReasoningEffort,
)
16
17
18
19
20
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import (
    SpecialTokenPolicy,
    SpecialTokens,
21
22
23
24
25
26
27
28
    Tokenizer,
)
from mistral_common.tokens.tokenizers.instruct import (
    InstructTokenizerBase,
    InstructTokenizerV13,
)
from mistral_common.tokens.tokenizers.mistral import (
    MistralTokenizer as MistralCommonTokenizer,
29
30
31
32
33
)
from mistral_common.tokens.tokenizers.sentencepiece import (
    SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
34
from pydantic import ValidationError
35

36
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
37
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
38
from vllm.logger import init_logger
39
from vllm.tokenizers.protocol import TokenizerLike
40

41
42
43
44
45
46
47
48
try:
    # Transformers v5
    from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
    # Transformers v4
    from transformers.tokenization_mistral_common import (
        MistralCommonTokenizer as MistralCommonBackend,
    )
49

50
if TYPE_CHECKING:
51
    import llguidance
52
    from transformers import BatchEncoding
53

54
55
# Module-level logger; used by the helper functions below to emit warnings.
logger = init_logger(__name__)

56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def _pop_unallowed_keys_and_warn(
    dictionary: dict[str, Any], allowed_keys: set[str], err_dict_name: str
):
    keys = list(dictionary.keys())
    for key in keys:
        if key not in allowed_keys:
            dictionary.pop(key)
            logger.warning_once(
                f"'{key=}' is not supported by mistral-common "
                f"for {err_dict_name}. It has been popped from the "
                "object."
            )


# TODO(juliendenize): remove this once OpenAI API is better supported by
# `mistral-common`.
def adapt_inplace_to_mistral_tool(
    tool: dict[str, Any],
) -> dict[str, Any]:
    """Make an OpenAI-style tool definition acceptable to mistral-common.

    The dict is modified in place:
    - a nested ``function`` gets ``parameters={}`` / ``description=""`` when
      those are missing or ``None``;
    - keys unknown to mistral-common's ``Function`` / ``Tool`` models are
      dropped, with a warning.

    Returns:
        The same ``tool`` object that was passed in.
    """
    allowed_tool_keys = set(Tool.model_fields.keys())
    allowed_function_keys = set(Function.model_fields.keys())

    function = tool.get("function")
    if function:
        # Unlike the OpenAI client, the Mistral client requires "parameters"
        # and "description" to be present even when they are empty.
        if function.get("parameters") is None:
            function["parameters"] = {}
        if function.get("description") is None:
            function["description"] = ""

        _pop_unallowed_keys_and_warn(function, allowed_function_keys, "function")

    _pop_unallowed_keys_and_warn(tool, allowed_tool_keys, "tools")
    return tool


101
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
    """Materialize lazily-validated ``tool_calls`` on assistant messages.

    For every assistant message in ``request.messages``, the ``tool_calls``
    entry is replaced in place by a plain list: the consumed iterator when
    one is present, otherwise an empty list. Other messages are untouched.

    Args:
        request: OpenAI-style request whose ``messages`` are dict-like. The
            body relies on ``message.get`` and item assignment.

    Raises:
        ValueError: If consuming ``tool_calls`` raises a pydantic
            ``ValidationError``.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: Because of a pydantic bug, attributes declared as iterables are
    # replaced on instances by a pydantic-core ValidatorIterator. This affects
    # `tool_calls` of ChatCompletionAssistantMessageParam: they are never
    # deserialized unless the iterator is consumed. No chat template is applied
    # for the MistralTokenizer, so nothing else consumes it. The official
    # workaround is to consume the iterator here.
    #   - https://github.com/pydantic/pydantic/issues/9467
    #   - https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    #   - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove once the pydantic fix (planned for v2.11) is the minimum
    # supported pydantic version.
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue

        tool_calls_validator = message.get("tool_calls", None)
        if tool_calls_validator is None:
            validated_tool_calls = []
        else:
            try:
                validated_tool_calls = list(tool_calls_validator)
            except ValidationError as e:
                raise ValueError(
                    "Validating messages' `tool_calls` raised an error. "
                    "Please ensure `tool_calls` are iterable of tool calls."
                ) from e

        request.messages[i]["tool_calls"] = validated_tool_calls


140
def truncate_tool_call_ids(request: "ChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements.

    IDs longer than 9 characters are replaced, in place, by their last 9
    characters. This applies to:
    - every ``tool_calls[*]["id"]`` of assistant messages (the assistant
      message's ``tool_calls`` key is always (re)assigned, ``[]`` when absent);
    - the ``tool_call_id`` of ``tool`` / ``tool_results`` messages.

    A warning is logged for each truncation.

    Args:
        request: OpenAI-style request whose ``messages`` are dict-like.
    """
    max_len = 9
    for i, message in enumerate(request.messages):
        role = message.get("role")
        if role == "assistant":
            # An explicit ``"tool_calls": None`` is treated like a missing key;
            # iterating ``None`` would otherwise raise TypeError.
            tool_calls = message.get("tool_calls") or []
            for tool_call in tool_calls:
                if len(tool_call["id"]) > max_len:
                    logger.warning(
                        "Truncating tool call ID: %s to %s",
                        tool_call["id"],
                        tool_call["id"][-max_len:],
                    )
                    tool_call["id"] = tool_call["id"][-max_len:]

            request.messages[i]["tool_calls"] = tool_calls

        elif role in {"tool_results", "tool"} and "tool_call_id" in message:
            tool_call_id = message["tool_call_id"]
            if len(tool_call_id) > max_len:
                logger.warning(
                    "Truncating tool_call_id: %s to %s",
                    tool_call_id,
                    tool_call_id[-max_len:],
                )
                tool_call_id = tool_call_id[-max_len:]
            request.messages[i]["tool_call_id"] = tool_call_id


170
171
def _prepare_apply_chat_template_tools_and_messages(
    messages: list["ChatCompletionMessageParam"],
172
    tools: list[dict[str, Any]] | None = None,
173
174
    continue_final_message: bool = False,
    add_generation_prompt: bool = False,
175
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
176
    if add_generation_prompt and continue_final_message:
177
        raise ValueError(
178
179
            "Cannot set both `add_generation_prompt` and "
            "`continue_final_message` to True."
180
        )
181

182
183
184
185
186
187
188
189
190
191
    last_message = cast(dict[str, Any], messages[-1])
    # add_generation_prompt is directly handled by the tokenizer but we
    # check if the user is trying to use it with a final assistant message
    # which is probably not what they want.
    # If add_generation_prompt is False, we don't need to check anything.
    if add_generation_prompt and last_message["role"] == "assistant":
        raise ValueError(
            "Cannot set `add_generation_prompt` to True when "
            "the last message is from the assistant. Consider "
            "using `continue_final_message` instead."
192
        )
193
194
195
196
    if continue_final_message and last_message["role"] != "assistant":
        raise ValueError(
            "Cannot set `continue_final_message` to True when "
            "the last message is not from the assistant."
197
        )
198

199
200
201
202
    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
203
204
        # Remove reasoning as unsupported by Mistral
        _ = message.pop("reasoning", None)  # type: ignore
205

206
207
208
209
210
    tools = (
        [adapt_inplace_to_mistral_tool(tool=tool) for tool in tools]
        if tools is not None
        else None
    )
211

212
    return messages, tools
213
214


215
216
217
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject request options that Mistral tokenizers cannot honor.

    Raises:
        ValueError: If ``chat_template`` or ``chat_template_kwargs`` is set.
            Also if ``reasoning_effort`` is set to a value that is not one of
            mistral-common's ``ReasoningEffort`` members.
    """
    if not (request.chat_template is None and request.chat_template_kwargs is None):
        raise ValueError("chat_template is not supported for Mistral tokenizers.")

    effort = request.reasoning_effort
    if not effort:
        # Unset (or empty) effort needs no validation.
        return

    supported_efforts = list(ReasoningEffort)
    if effort not in supported_efforts:
        raise ValueError(
            f"reasoning_effort={effort} is not supported by "
            "Mistral models. Supported values are: "
            f"{[e.value for e in ReasoningEffort]}."
        )

228

229
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
    """Map a single Tekken token (text or raw bytes) to its token id.

    The lookup order is:
    1. Look the token up in the non-special byte vocabulary, shifted by the
       number of special tokens.
    2. Otherwise, look its text up among the special tokens.
    3. Otherwise, log a warning and return ``tokenizer.unk_id``.

    Args:
        tokenizer: A mistral-common ``Tekkenizer``.
        t: The token, as ``str`` (UTF-8 encoded for lookup) or raw ``bytes``.
    """
    assert isinstance(tokenizer, Tekkenizer), type(tokenizer)

    t_bytes = t if isinstance(t, bytes) else t.encode("utf-8")
    # Non-special ids are offset by the number of special tokens.
    shift = tokenizer.num_special_tokens
    try:
        return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
    except KeyError:
        pass

    # errors="replace": a byte token that is not valid UTF-8 (e.g. a partial
    # multi-byte sequence) must reach the <unk> fallback below instead of
    # raising UnicodeDecodeError.
    t_str = t_bytes.decode("utf-8", errors="replace")
    if t_str in tokenizer._special_tokens_reverse_vocab:
        return tokenizer._special_tokens_reverse_vocab[t_str]
    logger.warning("Failed to convert token %s to id, replacing with <unk>", t_bytes)
    return tokenizer.unk_id

245

246
class MistralTokenizer(TokenizerLike):
    """vLLM ``TokenizerLike`` adapter around a mistral-common tokenizer.

    The tokenizer is loaded through transformers' ``MistralCommonBackend``.
    The underlying mistral-common objects are also kept, for functionality
    implemented directly against them (token/id conversion, grammars):
    ``mistral`` (the mistral-common tokenizer), ``instruct`` (its instruct
    tokenizer) and ``tokenizer`` (the raw Tekken/SentencePiece tokenizer).
    """

    IS_MISTRAL_TOKENIZER = True  # used by vllm.utils.mistral

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "MistralTokenizer":
        """Load the tokenizer from a local path or a Hub repo id.

        Notes:
        - The backend is always created with ``mode=ValidationMode.test``,
          which ``__init__`` requires.
        - ``download_dir`` is forwarded as ``cache_dir``.
        - ``revision`` defaults to ``"main"``.
        - ``trust_remote_code`` is accepted for interface compatibility and is
          not used.
        """
        tokenizer = MistralCommonBackend.from_pretrained(
            path_or_repo_id,
            *args,
            mode=ValidationMode.test,
            cache_dir=download_dir,
            revision="main" if revision is None else revision,
            **kwargs,
        )

        return cls(tokenizer)

    def __init__(self, tokenizer: MistralCommonBackend) -> None:
        """Wrap ``tokenizer`` and precompute vocab / special-token caches.

        Raises:
            ValueError: If the mistral-common validator is not in test mode.
            TypeError: If the raw tokenizer is neither Tekken nor
                SentencePiece.
        """
        super().__init__()

        self.transformers_tokenizer: MistralCommonBackend = tokenizer
        self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
        self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
        self.tokenizer: Tokenizer = self.instruct.tokenizer

        mode = self.mistral._chat_completion_request_validator._mode
        if mode != ValidationMode.test:
            raise ValueError(
                "Mistral tokenizer must be in test mode. Make sure to "
                "set `mode='ValidationMode.test'` when creating the "
                "Mistral tokenizer."
            )

        # The version value is expected to look like "v<N>"; keep N as an int.
        _mistral_version_str = str(self.tokenizer.version.value)
        self.version: int = int(_mistral_version_str.split("v")[-1])

        self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
        self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
        if not (self.is_tekken or self.is_spm):
            raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")

        # Reverse order to ensure that the lowest token id is kept.
        self._vocab_dict = {
            self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i
            for i in range(self.vocab_size - 1, -1, -1)
        }
        # Sort the dict for convenience
        self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))

        # Vocab sorted by token id.
        self._vocab = self.tokenizer.vocab()
        self._max_token_id = self.vocab_size - 1
        self._max_chars_per_token = max(len(tok) for tok in self._vocab)

        # Cache special tokens for faster access.
        self._special_token_ids = self._get_special_token_ids()
        self._special_token_ids_set = set(self._special_token_ids)
        self._special_tokens = self._get_special_tokens(self._special_token_ids)
        self._special_tokens_set = set(self._special_tokens)

    def _get_special_token_ids(self) -> list[int]:
        """Return, in id order, every vocab id the raw tokenizer marks special."""
        return [i for i in range(len(self._vocab)) if self.tokenizer.is_special(i)]

    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
        """Return the text of each id in ``all_special_ids``, keeping specials."""
        return [
            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
            for i in all_special_ids
        ]

    def num_special_tokens_to_add(self) -> int:
        """Number of special tokens ``encode`` adds to an empty input."""
        return len(self.encode(""))

    # the following attributes are set to fit vLLM's design and are used
    # by the structured output backends.
    @property
    def all_special_tokens(self) -> list[str]:
        return self._special_tokens

    @property
    def all_special_ids(self) -> list[int]:
        return self._special_token_ids

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def pad_token_id(self) -> int:
        return self.tokenizer.pad_id

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def vocab_size(self) -> int:
        return self.transformers_tokenizer.vocab_size

    @property
    def max_token_id(self) -> int:
        return self._max_token_id

    @property
    def max_chars_per_token(self) -> int:
        return self._max_chars_per_token

    @property
    def truncation_side(self) -> str:
        return self.transformers_tokenizer.truncation_side

    def _is_special_token_id(self, token_id: int) -> bool:
        return token_id in self._special_token_ids_set

    def __hash__(self) -> int:
        # Identity-based hash (equality is the default identity comparison).
        return hash(id(self))

    def __len__(self) -> int:
        return self.vocab_size

    def _strip_trailing_eos(
        self, input_ids: list[int], attention_mask: list[int] | None
    ) -> None:
        """Drop a trailing EOS id, and its mask entry, from one sequence."""
        if input_ids and input_ids[-1] == self.eos_token_id:
            input_ids.pop(-1)
            if attention_mask:
                attention_mask.pop(-1)

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> "BatchEncoding":
        """Tokenize ``text`` (a string or a batch of strings).

        Raises:
            ValueError: If ``text_pair`` is given (unsupported).
        """
        if text_pair is not None:
            raise ValueError(
                "`text_pair` is not supported by `MistralTokenizer.__call__`."
            )

        encoded = self.transformers_tokenizer(
            text=text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )

        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, revert to only call self.transformers_tokenizer(...).
        # Hack to fix wrongly added eos token, when fix will be supported the condition
        # below will be False even before the revert is done.
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")
        if isinstance(text, str):
            self._strip_trailing_eos(input_ids, attention_mask)
        else:
            # Batched input: input_ids / attention_mask hold one list per text.
            # Comparing the outer list's last element to an int would never
            # match, so the fix-up has to be applied per sequence.
            for idx, ids in enumerate(input_ids):
                self._strip_trailing_eos(
                    ids, attention_mask[idx] if attention_mask else None
                )
        return encoded

    @property
    def vocab(self) -> list[str]:
        return self._vocab

    def get_vocab(self) -> dict[str, int]:
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        """Encode ``text`` into ids (never appends EOS).

        ``add_special_tokens`` controls the BOS token. The result is cut to
        ``max_length`` whenever ``max_length`` is given and ``truncation`` is
        not explicitly ``False``.
        """
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.encode(...).
        encoded = self.tokenizer.encode(text, bos=add_special_tokens, eos=False)

        if truncation is not False and max_length is not None:
            return encoded[:max_length]
        else:
            return encoded

    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: list[dict[str, Any]] | None = None,
        **kwargs,
    ) -> list[int]:
        """Render a conversation through the transformers backend.

        ``add_generation_prompt`` is validated but not forwarded to the
        backend. ``reasoning_effort`` is forwarded only for tokenizer
        version >= 15. ``messages`` and ``tools`` are sanitized in place
        first.
        """
        add_generation_prompt = kwargs.pop("add_generation_prompt", False)
        continue_final_message = kwargs.get("continue_final_message", False)
        tokenize = kwargs.get("tokenize", True)
        padding = kwargs.get("padding", False)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        version_kwargs = {}
        # NOTE: This is for backward compatibility.
        # Transformers should be passed arguments it knows.
        if self.version >= 15:
            version_kwargs["reasoning_effort"] = kwargs.get("reasoning_effort")

        messages, tools = _prepare_apply_chat_template_tools_and_messages(
            messages, tools, continue_final_message, add_generation_prompt
        )

        return self.transformers_tokenizer.apply_chat_template(
            conversation=messages,
            tools=tools,
            continue_final_message=continue_final_message,
            tokenize=tokenize,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=None,
            return_dict=False,
            **version_kwargs,
        )

    def decode(
        self, ids: Sequence[int] | int, skip_special_tokens: bool = False
    ) -> str:
        """Decode ``ids`` (a single int is accepted) into text."""
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.decode(...).
        if isinstance(ids, int):
            ids = [ids]

        return self.transformers_tokenizer.decode(
            ids, skip_special_tokens=skip_special_tokens
        )

    def batch_decode(
        self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        """Decode each sequence in ``ids``; returns one string per sequence."""
        return self.transformers_tokenizer.batch_decode(
            ids, skip_special_tokens=skip_special_tokens
        )

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        return self.transformers_tokenizer.convert_tokens_to_ids(tokens)

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join tokens into text.

        Special tokens are dropped, except tool-calls and begin/end-think,
        which are kept verbatim.
        """
        to_decode_special_tokens = {
            SpecialTokens.tool_calls,
            SpecialTokens.begin_think,
            SpecialTokens.end_think,
        }
        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            tokens = [
                t
                for t in tokens
                if (t in to_decode_special_tokens or t not in self._special_tokens_set)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens]
                # We filtered unwanted special tokens before
                # so we can decode the rest.
                decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
            else:
                decoded = "".join(tokens)
        else:
            # make sure certain special tokens like Tool calls are
            # not decoded
            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
                self.tokenizer
            )

            regular_tokens: list[str] = []
            decoded_list: list[str] = []

            for token in tokens:
                if token in to_decode_special_tokens:
                    # Flush the pending run of regular tokens before the
                    # special token so their relative order is preserved.
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(
                                regular_tokens, SpecialTokenPolicy.IGNORE
                            )
                        )
                        regular_tokens = []
                    decoded_list.append(token)
                else:
                    regular_tokens.append(token)

            if regular_tokens:
                decoded_list.append(
                    self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
                )
            decoded = "".join(decoded_list)

        return decoded

    def convert_ids_to_tokens(
        self,
        ids: Sequence[int],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """Convert ids to token pieces.

        With ``skip_special_tokens``, special ids are dropped except
        tool-calls and (for V13+ instruct tokenizers) begin/end-think.
        For Tekken, pieces containing U+FFFD are returned as raw bytes
        instead (see below).
        """
        if not skip_special_tokens:
            return [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        non_skip_special_tokens_ids = {
            self.tokenizer.get_special_token(SpecialTokens.tool_calls),
        }
        if isinstance(self.instruct, InstructTokenizerV13):
            if self.instruct.BEGIN_THINK:
                non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
            if self.instruct.END_THINK:
                non_skip_special_tokens_ids.add(self.instruct.END_THINK)

        ids_kept = [
            i
            for i in ids
            if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
        ]

        # We filtered unwanted special tokens so we can decode the rest.
        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]

        if any("�" in t for t in tokens) and self.is_tekken:
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # if underlying tokenizer is sentencepiece, we just add "�".
            # We filtered unwanted special tokens so we can decode the rest.
            tokens = [
                self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
                if token_id not in self._special_token_ids_set
                else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
                for token_id in ids_kept
            ]

        return tokens

    @property
    def supports_grammar(self) -> bool:
        return GrammarFactory.is_supported(self.mistral)

    @cached_property
    def grammar_factory(self) -> GrammarFactory:
        if not self.supports_grammar:
            raise AttributeError(
                "This tokenizer does not support `grammar_factory`. "
                "This is only supported for tekken tokenizers with "
                "version >= 11."
            )
        # Cache grammar factory to avoid creating a llguidance tokenizer at every usage.
        return GrammarFactory(self.mistral)

    @cached_property
    def llg_tokenizer(self) -> "llguidance.LLTokenizer":
        if not self.is_tekken:
            raise ValueError("`llg_tokenizer` is only supported for Tekkenizers.")
        return from_mistral_tokenizer(self.mistral)