mistral.py 20.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from pathlib import Path
4
from typing import TYPE_CHECKING, Any, cast, overload
5

6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from mistral_common.protocol.instruct.request import (
    ChatCompletionRequest as MistralChatCompletionRequest,
)
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import (
    SpecialTokenPolicy,
    SpecialTokens,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
from mistral_common.tokens.tokenizers.sentencepiece import (
    SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer

21
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
22
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
23
from vllm.logger import init_logger
24
25

from .protocol import TokenizerLike
26

27
if TYPE_CHECKING:
28
    from transformers import BatchEncoding
29

30
31
32
33
34
35
36
37
38
    try:
        # Transformers v5
        from transformers.tokenization_mistral_common import MistralCommonBackend
    except ImportError:
        # Transformers v4
        from transformers.tokenization_mistral_common import (
            MistralCommonTokenizer as MistralCommonBackend,
        )

39
40
# Module-level logger used by the helper functions below (tool-call ID
# truncation, tool-field filtering and Tekken token lookup warnings).
logger = init_logger(__name__)

41

42
def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
    """Materialize every assistant message's ``tool_calls`` into a list, in place.

    After this call each assistant message carries a ``"tool_calls"`` list:
    the consumed iterator's items, a copy of an already materialized
    sequence, or ``[]`` when the field is missing or ``None``. Messages with
    any other role are left untouched.

    Args:
        request: Object exposing a mutable ``messages`` list of dict-like
            messages. It is modified in place.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
    # declared as iterables are replaced in the instances by
    # pydantic-core ValidatorIterator instances. In particular, this
    # affects tool_calls defined in ChatCompletionAssistantMessageParam.
    # see:
    #   - https://github.com/pydantic/pydantic/issues/9467
    # As a result, tool_calls from assistant messages are never
    # deserialized in the request object if the tool_calls iterator is
    # not consumed. This affects messages passed to the MistralTokenizer
    # since no chat template is applied and therefore the tool_calls
    # iterator is not directly consumed.
    # The official workaround is to consume the iterator so the tool_calls
    # are correctly deserialized:
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # Official Pydantic Issues:
    #   - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        # `list()` drains a ValidatorIterator and also accepts an already
        # materialized list/tuple; the previous `next()`-based loop raised
        # TypeError for those and for `None`. `None` is treated like a
        # missing field.
        request.messages[i]["tool_calls"] = (
            [] if tool_calls is None else list(tool_calls)
        )


79
def truncate_tool_call_ids(request: "MistralChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements.

    Any tool call ID longer than 9 characters is replaced, in place, by its
    last 9 characters, and a warning is logged for each truncation. This
    applies to:

    * the ``"id"`` of every entry in an assistant message's ``tool_calls``,
      which is always rewritten as a plain list (``[]`` when missing/None);
    * the ``"tool_call_id"`` of ``"tool"`` / ``"tool_results"`` messages.

    Args:
        request: Object exposing a mutable ``messages`` list of dict-like
            messages. It is modified in place.
    """
    for i, message in enumerate(request.messages):
        role = message.get("role")

        if role == "assistant":
            # Materialize first: `None` must not crash the loop, and a lazy
            # iterator (see `maybe_serialize_tool_calls`) would otherwise be
            # consumed here and written back exhausted, losing the calls.
            tool_calls = list(message.get("tool_calls") or [])
            for tool_call in tool_calls:
                tool_call_id = tool_call["id"]
                if len(tool_call_id) > 9:
                    logger.warning(
                        "Truncating tool call ID: %s to %s",
                        tool_call_id,
                        tool_call_id[-9:],
                    )
                    tool_call["id"] = tool_call_id[-9:]

            request.messages[i]["tool_calls"] = tool_calls

        elif role in {"tool_results", "tool"} and "tool_call_id" in message:
            tool_call_id = message["tool_call_id"]
            if len(tool_call_id) > 9:
                logger.warning(
                    "Truncating tool_call_id: %s to %s",
                    tool_call_id,
                    tool_call_id[-9:],
                )
                tool_call_id = tool_call_id[-9:]
            request.messages[i]["tool_call_id"] = tool_call_id


109
110
def _prepare_apply_chat_template_tools_and_messages(
    messages: list["ChatCompletionMessageParam"],
111
    tools: list[dict[str, Any]] | None = None,
112
113
    continue_final_message: bool = False,
    add_generation_prompt: bool = False,
114
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
115
    if add_generation_prompt and continue_final_message:
116
        raise ValueError(
117
118
            "Cannot set both `add_generation_prompt` and "
            "`continue_final_message` to True."
119
        )
120

121
122
123
124
125
126
127
128
129
130
    last_message = cast(dict[str, Any], messages[-1])
    # add_generation_prompt is directly handled by the tokenizer but we
    # check if the user is trying to use it with a final assistant message
    # which is probably not what they want.
    # If add_generation_prompt is False, we don't need to check anything.
    if add_generation_prompt and last_message["role"] == "assistant":
        raise ValueError(
            "Cannot set `add_generation_prompt` to True when "
            "the last message is from the assistant. Consider "
            "using `continue_final_message` instead."
131
        )
132
133
134
135
    if continue_final_message and last_message["role"] != "assistant":
        raise ValueError(
            "Cannot set `continue_final_message` to True when "
            "the last message is not from the assistant."
136
        )
137

138
139
140
141
    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
142
143
        # Remove reasoning as unsupported by Mistral
        _ = message.pop("reasoning", None)  # type: ignore
144

145
    # The Mistral client, in comparison to the OpenAI client, requires the
146
147
    # "parameters" dict and the "description" string to be present
    # even if they are empty.
148
149
    if tools:
        for function in [
150
            tool["function"] for tool in tools if tool["type"] == "function"
151
        ]:
152
153
            if function.get("parameters") is None:
                function["parameters"] = {}
154
155
            if function.get("description") is None:
                function["description"] = ""
156

157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
        # We filter not supported arguments to avoid throwing an error.
        # TODO(juliendenize): remove this once OpenAI API is better supported by
        # `mistral-common`.
        tools_fields = set(Tool.model_fields.keys())
        function_fields = set(Function.model_fields.keys())
        for tool in tools:
            tool_keys = list(tool.keys())
            for tool_key in tool_keys:
                if tool_key not in tools_fields:
                    tool.pop(tool_key)
                    logger.warning_once(
                        f"'{tool_key}' is not supported by mistral-common for tools. "
                        "It has been poped from the tool definition."
                    )
                if tool["type"] == "function":
                    function_keys = list(tool["function"].keys())
                    for function_key in function_keys:
                        if function_key not in function_fields:
                            tool["function"].pop(function_key)
                            logger.warning_once(
                                f"'{function_key}' is not supported by mistral-common "
                                "for function tools. It has been poped from the "
                                "function definition."
                            )
                else:
                    raise ValueError("mistral-common only supports function tools.")

184
    return messages, tools
185
186


187
188
189
def validate_request_params(request: "ChatCompletionRequest"):
    """Refuse request options that Mistral tokenizers cannot honor.

    Custom chat templates are not supported for Mistral tokenizers, so a
    request setting either ``chat_template`` or ``chat_template_kwargs``
    (to anything other than ``None``) is rejected.

    Raises:
        ValueError: if either option is set.
    """
    error_message = "chat_template is not supported for Mistral tokenizers."
    # Checked one after the other so that `chat_template_kwargs` is only
    # inspected when no `chat_template` was given.
    if request.chat_template is not None:
        raise ValueError(error_message)
    if request.chat_template_kwargs is not None:
        raise ValueError(error_message)
190
191


192
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
    """Map a Tekken token (text or raw bytes) back to its token id.

    Lookup order:

    1. The non-special byte vocabulary, offset by ``num_special_tokens``.
    2. The special-token reverse vocabulary, keyed by the UTF-8 decoded text.
    3. Otherwise a warning is logged and ``tokenizer.unk_id`` is returned.

    Args:
        tokenizer: The Tekken tokenizer to look the token up in.
        t: The token, either as text (encoded to UTF-8) or as raw bytes.

    Returns:
        The token id, or ``tokenizer.unk_id`` if the token is unknown.
    """
    assert isinstance(tokenizer, Tekkenizer), type(tokenizer)

    t_bytes = t if isinstance(t, bytes) else t.encode("utf-8")
    shift = tokenizer.num_special_tokens
    try:
        return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
    except KeyError:
        pass

    # The special-token vocabulary is keyed by `str`, so a byte sequence that
    # is not valid UTF-8 can never match one. Fall through to <unk> instead of
    # letting UnicodeDecodeError escape (partial UTF-8 byte pieces are exactly
    # what callers pass here).
    try:
        t_str: str | None = t_bytes.decode("utf-8")
    except UnicodeDecodeError:
        t_str = None

    if t_str is not None and t_str in tokenizer._special_tokens_reverse_vocab:
        return tokenizer._special_tokens_reverse_vocab[t_str]

    logger.warning(
        "Failed to convert token %s to id, replacing with <unk>", t_bytes
    )
    return tokenizer.unk_id

208

209
class MistralTokenizer(TokenizerLike):
    """vLLM tokenizer adapter around a transformers mistral-common backend.

    Wraps ``MistralCommonBackend`` (``MistralCommonTokenizer`` on
    transformers v4) and exposes the ``TokenizerLike`` interface. The raw
    underlying tokenizer must be either a Tekken or a SentencePiece
    tokenizer.
    """

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "MistralTokenizer":
        """Load a tokenizer from a local path or a Hugging Face Hub repo.

        Args:
            path_or_repo_id: Local directory or Hub repository id.
            *args: Forwarded positionally to the backend's
                ``from_pretrained``.
            trust_remote_code: Accepted for interface compatibility; unused.
            revision: Revision to load; ``None`` means ``"main"``.
            download_dir: Forwarded to the backend as ``cache_dir``.
            **kwargs: Forwarded to the backend's ``from_pretrained``.

        Returns:
            A ``MistralTokenizer`` whose backend was created with
            ``mode=ValidationMode.test``.
        """
        try:
            # Transformers v5
            from transformers.tokenization_mistral_common import MistralCommonBackend
        except ImportError:
            # Transformers v4
            from transformers.tokenization_mistral_common import (
                MistralCommonTokenizer as MistralCommonBackend,
            )

        tokenizer = MistralCommonBackend.from_pretrained(
            path_or_repo_id,
            *args,
            mode=ValidationMode.test,
            cache_dir=download_dir,
            revision="main" if revision is None else revision,
            **kwargs,
        )

        return cls(tokenizer)

    def __init__(self, tokenizer: "MistralCommonBackend") -> None:
        """Wrap an already loaded transformers mistral-common backend.

        Raises:
            ValueError: if the backend's request validator is not in
                ``ValidationMode.test``.
            TypeError: if the raw tokenizer is neither Tekken nor
                SentencePiece.
        """
        super().__init__()

        # Unwrap the layers: backend -> `.tokenizer` -> `.instruct_tokenizer`
        # -> `.tokenizer` (the raw Tekken/SentencePiece tokenizer).
        self.transformers_tokenizer = tokenizer
        self.mistral = tokenizer.tokenizer
        self.instruct = self.mistral.instruct_tokenizer
        self.tokenizer = self.instruct.tokenizer

        mode = self.mistral._chat_completion_request_validator._mode
        if mode != ValidationMode.test:
            raise ValueError(
                "Mistral tokenizer must be in test mode. Make sure to "
                "set `mode='ValidationMode.test'` when creating the "
                "Mistral tokenizer."
            )

        # Version enum values look like "v<N>"; keep the numeric part.
        _mistral_version_str = str(self.tokenizer.version.value)
        self.version: int = int(_mistral_version_str.split("v")[-1])

        self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
        self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
        if not (self.is_tekken or self.is_spm):
            raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")

        # Reverse order to ensure that the lowest token id is kept.
        self._vocab_dict = {
            self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i
            for i in range(self.vocab_size - 1, -1, -1)
        }
        # Sort the dict for convenience
        self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))

        # Vocab sorted by token id.
        self._vocab = self.tokenizer.vocab()
        self._max_token_id = self.vocab_size - 1

        # Cache special tokens for faster access.
        self._special_token_ids = self._get_special_token_ids()
        self._special_token_ids_set = set(self._special_token_ids)
        self._special_tokens = self._get_special_tokens(self._special_token_ids)
        self._special_tokens_set = set(self._special_tokens)

    def _get_special_token_ids(self) -> list[int]:
        """Return every id in ``self._vocab`` the raw tokenizer marks special."""
        return [i for i in range(len(self._vocab)) if self.tokenizer.is_special(i)]

    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
        """Decode each special id individually, keeping special tokens."""
        return [
            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
            for i in all_special_ids
        ]

    def num_special_tokens_to_add(self) -> int:
        """Number of ids ``encode`` produces for empty text (special tokens only)."""
        return len(self.encode(""))

    # the following attributes are set to fit vLLM's design and are used
    # by the structured output backends.
    @property
    def all_special_tokens(self) -> list[str]:
        return self._special_tokens

    @property
    def all_special_ids(self) -> list[int]:
        return self._special_token_ids

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def pad_token_id(self) -> int:
        return self.tokenizer.pad_id

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def vocab_size(self) -> int:
        return self.transformers_tokenizer.vocab_size

    @property
    def max_token_id(self) -> int:
        return self._max_token_id

    @property
    def truncation_side(self) -> str:
        return self.transformers_tokenizer.truncation_side

    def _is_special_token_id(self, token_id: int) -> bool:
        return token_id in self._special_token_ids_set

    def __hash__(self) -> int:
        # Identity-based, consistent with the default identity equality.
        return hash(id(self))

    def __len__(self) -> int:
        return self.vocab_size

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> "BatchEncoding":
        """Tokenize ``text`` (a string or a batch) via the transformers backend.

        A trailing EOS id is stripped from every returned sequence (and its
        attention-mask entry), see the hack below.

        Raises:
            ValueError: if ``text_pair`` is given; pairs are not supported.
        """
        if text_pair is not None:
            raise ValueError(
                "`text_pair` is not supported by `MistralTokenizer.__call__`."
            )

        encoded = self.transformers_tokenizer(
            text=text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )

        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, revert to only call self.transformers_tokenizer(...).
        # Hack to fix wrongly added eos token, when fix will be supported the condition
        # below will be False even before the revert is done.
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")
        if input_ids and isinstance(input_ids[0], list):
            # Batched input: `input_ids` is a list of sequences. Previously the
            # last *sequence* was compared to an int, so nothing was stripped.
            for row, ids in enumerate(input_ids):
                if ids and ids[-1] == self.eos_token_id:
                    ids.pop(-1)
                    if attention_mask and attention_mask[row]:
                        attention_mask[row].pop(-1)
        elif input_ids and input_ids[-1] == self.eos_token_id:
            input_ids.pop(-1)
            if attention_mask:
                attention_mask.pop(-1)
        return encoded

    @property
    def vocab(self) -> list[str]:
        return self._vocab

    def get_vocab(self) -> dict[str, int]:
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        """Encode ``text`` with the raw tokenizer (BOS optional, never EOS).

        Args:
            text: Text to encode.
            truncation: Unless explicitly ``False`` (the default ``None``
                counts as enabled), a given ``max_length`` is applied.
            max_length: Keep only the first ``max_length`` ids.
            add_special_tokens: Whether to prepend BOS.
        """
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.encode(...).
        encoded = self.tokenizer.encode(text, bos=add_special_tokens, eos=False)

        if truncation is not False and max_length is not None:
            return encoded[:max_length]
        else:
            return encoded

    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: list[dict[str, Any]] | None = None,
        **kwargs,
    ) -> list[int]:
        """Render ``messages`` (and ``tools``) via the backend's chat template.

        ``add_generation_prompt`` is only used for validation and is not
        forwarded. ``continue_final_message``, ``tokenize``, ``padding``,
        ``truncation`` and ``max_length`` are read from ``kwargs`` and
        forwarded; other kwargs are ignored. Messages and tools are adapted
        in place by ``_prepare_apply_chat_template_tools_and_messages``.
        """
        add_generation_prompt = kwargs.pop("add_generation_prompt", False)
        continue_final_message = kwargs.get("continue_final_message", False)
        tokenize = kwargs.get("tokenize", True)
        padding = kwargs.get("padding", False)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        messages, tools = _prepare_apply_chat_template_tools_and_messages(
            messages, tools, continue_final_message, add_generation_prompt
        )

        return self.transformers_tokenizer.apply_chat_template(
            conversation=messages,
            tools=tools,
            continue_final_message=continue_final_message,
            tokenize=tokenize,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=None,
            return_dict=False,
        )

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        """Decode one id or a list of ids into text via the backend."""
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.decode(...).
        if isinstance(ids, int):
            ids = [ids]

        return self.transformers_tokenizer.decode(
            ids, skip_special_tokens=skip_special_tokens
        )

    def batch_decode(
        self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        """Decode several id sequences; returns one string per sequence."""
        return self.transformers_tokenizer.batch_decode(
            ids, skip_special_tokens=skip_special_tokens
        )

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        return self.transformers_tokenizer.convert_tokens_to_ids(tokens)

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join token pieces back into text.

        Special tokens are dropped, except ``SpecialTokens.tool_calls`` which
        is kept verbatim.

        * Tekken: if any piece is raw ``bytes`` (an incomplete UTF-8
          fragment), all pieces are mapped back to ids and decoded together.
          Otherwise the pieces are simply concatenated.
        * SentencePiece: runs of regular pieces are decoded with
          ``SpecialTokenPolicy.IGNORE``; tool-call tokens are inserted as-is.
        """
        to_decode_special_tokens = {SpecialTokens.tool_calls}
        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            tokens = [
                t
                for t in tokens
                if (t in to_decode_special_tokens or t not in self._special_tokens_set)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens]
                # We filtered unwanted special tokens before
                # so we can decode the rest.
                decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
            else:
                decoded = "".join(tokens)
        else:
            # make sure certain special tokens like Tool calls are
            # not decoded
            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
                self.tokenizer
            )

            regular_tokens: list[str] = []
            decoded_list: list[str] = []

            for token in tokens:
                if token in to_decode_special_tokens:
                    # Flush the pending run of regular pieces first.
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(
                                regular_tokens, SpecialTokenPolicy.IGNORE
                            )
                        )
                        regular_tokens = []
                    decoded_list.append(token)
                else:
                    regular_tokens.append(token)

            if regular_tokens:
                decoded_list.append(
                    self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
                )
            decoded = "".join(decoded_list)

        return decoded

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """Map ids to token pieces.

        With ``skip_special_tokens``, special ids are dropped except the tool
        calls token and, for ``InstructTokenizerV13``, the begin/end think
        tokens. For Tekken, if any piece contains the replacement character
        "�" (an incomplete UTF-8 sequence), non-special pieces are returned
        as raw bytes instead.
        """
        if not skip_special_tokens:
            return [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        non_skip_special_tokens_ids = {
            self.tokenizer.get_control_token(SpecialTokens.tool_calls),
        }
        if isinstance(self.instruct, InstructTokenizerV13):
            if self.instruct.BEGIN_THINK:
                non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
            if self.instruct.END_THINK:
                non_skip_special_tokens_ids.add(self.instruct.END_THINK)

        ids_kept = [
            i
            for i in ids
            if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
        ]

        # We filtered unwanted special tokens so we can decode the rest.
        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]

        if any("�" in t for t in tokens) and self.is_tekken:
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # if underlying tokenizer is sentencepiece, we just add "�".
            # We filtered unwanted special tokens so we can decode the rest.
            tokens = [
                self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
                if token_id not in self._special_token_ids_set
                else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
                for token_id in ids_kept
            ]

        return tokens