mistral.py 19.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, cast
5

6
from vllm.logger import init_logger
7
8

from .protocol import TokenizerLike
9

10
if TYPE_CHECKING:
11
12
13
14
15
16
    from mistral_common.protocol.instruct.request import (
        ChatCompletionRequest as MistralChatCompletionRequest,
    )
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
    from transformers.tokenization_mistral_common import (
        MistralCommonTokenizer as TransformersMistralTokenizer,
17
    )
18

19
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
20
    from vllm.entrypoints.openai.protocol import ChatCompletionRequest
21

22
23
# Module-level logger; the helpers below use it to warn about truncated tool
# call IDs and about tokens that cannot be mapped to an id.
logger = init_logger(__name__)

24

25
def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
    """Materialize the ``tool_calls`` of every assistant message into a list.

    The request is modified in place; nothing is returned. For each message
    whose ``role`` is ``"assistant"``, ``tool_calls`` is replaced by a plain
    ``list`` holding the items of whatever iterable was stored there. A
    missing or ``None`` ``tool_calls`` becomes an empty list. Messages with
    any other role are left untouched.

    Args:
        request: Object exposing a ``messages`` sequence of dict-like
            messages.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
    # declared as iterables are replaced in the instances by
    # pydantic-core ValidatorIterator instance. In particular, this
    # affects tool_calls defined in ChatCompletionAssistantMessageParam
    # model:
    # see:
    #   - https://github.com/pydantic/pydantic/issues/9467
    # As a result, tool_calls from assistant messages are never
    # deserialized in the request object if the tool_calls iterator is
    # not consumed. This affects messages passed to the MistralTokenizer
    # since no chat template is applied and therefore the tools_calls
    # iterator is not directly consumed.
    # Issue is tracked on Pydantic side, with resolution planned for
    # v2.11 release. In the meantime, the official workaround is to
    # consume the iterator so the tool_calls are correctly deserialized
    # in the OpenAI ChatCompletionAssistantMessageParam object
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # Official Pydantic Issues:
    #   - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        # `list()` drains one-shot iterators and also accepts values that are
        # already lists/tuples, which a bare `next()` loop would reject with
        # a TypeError. An explicit `None` is treated like a missing key.
        validated_tool_calls = [] if tool_calls is None else list(tool_calls)

        request.messages[i]["tool_calls"] = validated_tool_calls


62
def truncate_tool_call_ids(request: "MistralChatCompletionRequest"):
    """Shorten over-long tool call IDs in place for Mistral's ID requirements.

    Any ID longer than 9 characters is replaced by its last 9 characters and
    a warning is logged. Two places are rewritten:

    * ``tool_calls[*]["id"]`` on messages whose role is ``"assistant"``
      (the ``tool_calls`` key is always written back, as an empty list when
      it was absent);
    * ``tool_call_id`` on messages whose role is ``"tool"`` or
      ``"tool_results"``, when that key is present.
    """
    for idx, msg in enumerate(request.messages):
        if msg.get("role") == "assistant":
            calls = msg.get("tool_calls", [])
            for call in calls:
                if len(call["id"]) <= 9:
                    continue
                logger.warning(
                    "Truncating tool call ID: %s to %s",
                    call["id"],
                    call["id"][-9:],
                )
                call["id"] = call["id"][-9:]

            request.messages[idx]["tool_calls"] = calls

        elif msg.get("role") in {"tool_results", "tool"} and "tool_call_id" in msg:
            call_id = msg["tool_call_id"]
            shortened = call_id[-9:] if len(call_id) > 9 else call_id
            if shortened is not call_id:
                logger.warning(
                    "Truncating tool_call_id: %s to %s",
                    call_id,
                    shortened,
                )
            # Written back unconditionally, even when nothing was cut.
            request.messages[idx]["tool_call_id"] = shortened


92
93
def _prepare_apply_chat_template_tools_and_messages(
    messages: list["ChatCompletionMessageParam"],
94
    tools: list[dict[str, Any]] | None = None,
95
96
    continue_final_message: bool = False,
    add_generation_prompt: bool = False,
97
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
98
    if add_generation_prompt and continue_final_message:
99
        raise ValueError(
100
101
            "Cannot set both `add_generation_prompt` and "
            "`continue_final_message` to True."
102
        )
103

104
105
106
107
108
109
110
111
112
113
    last_message = cast(dict[str, Any], messages[-1])
    # add_generation_prompt is directly handled by the tokenizer but we
    # check if the user is trying to use it with a final assistant message
    # which is probably not what they want.
    # If add_generation_prompt is False, we don't need to check anything.
    if add_generation_prompt and last_message["role"] == "assistant":
        raise ValueError(
            "Cannot set `add_generation_prompt` to True when "
            "the last message is from the assistant. Consider "
            "using `continue_final_message` instead."
114
        )
115
116
117
118
    if continue_final_message and last_message["role"] != "assistant":
        raise ValueError(
            "Cannot set `continue_final_message` to True when "
            "the last message is not from the assistant."
119
        )
120

121
122
123
124
    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
125
126
        # Remove reasoning as unsupported by Mistral
        _ = message.pop("reasoning", None)  # type: ignore
127

128
    # The Mistral client, in comparison to the OpenAI client, requires the
129
130
    # "parameters" dict and the "description" string to be present
    # even if they are empty.
131
132
    if tools:
        for function in [
133
            tool["function"] for tool in tools if tool["type"] == "function"
134
        ]:
135
136
            if function.get("parameters") is None:
                function["parameters"] = {}
137
138
            if function.get("description") is None:
                function["description"] = ""
139

140
    return messages, tools
141
142


143
144
145
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject requests that carry a chat template or chat template kwargs.

    Raises:
        ValueError: If ``request.chat_template`` or
            ``request.chat_template_kwargs`` is not ``None``.
    """
    has_template = request.chat_template is not None
    if has_template or request.chat_template_kwargs is not None:
        raise ValueError("chat_template is not supported for Mistral tokenizers.")
146
147


148
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
    """Map a Tekken token (text or raw bytes) to its token id.

    Lookup order:
      1. the tokenizer's non-special vocabulary, offset by
         ``tokenizer.num_special_tokens``;
      2. the tokenizer's special-token reverse vocabulary (keyed by the
         UTF-8 decoded token);
      3. otherwise a warning is logged and ``tokenizer.unk_id`` is returned.
    """
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer

    assert isinstance(tokenizer, Tekkenizer), type(tokenizer)

    token_bytes = t if isinstance(t, bytes) else t.encode("utf-8")
    offset = tokenizer.num_special_tokens
    try:
        regular_rank = tokenizer._tekken_token2id_nospecial[token_bytes]
    except KeyError:
        # Not a regular token: fall through to the special-token lookup.
        pass
    else:
        return offset + regular_rank

    token_text = token_bytes.decode("utf-8")
    special_vocab = tokenizer._special_tokens_reverse_vocab
    if token_text in special_vocab:
        return special_vocab[token_text]

    logger.warning(
        "Failed to convert token %s to id, replacing with <unk>", token_bytes
    )
    return tokenizer.unk_id

166

167
class MistralTokenizer(TokenizerLike):
    """vLLM tokenizer facade over a ``mistral_common``-backed HF tokenizer.

    The wrapped ``transformers`` tokenizer is kept as
    ``transformers_tokenizer``; the ``mistral_common`` objects it contains
    are exposed as ``mistral`` (top-level tokenizer), ``instruct`` (instruct
    tokenizer) and ``tokenizer`` (the underlying Tekken or SentencePiece
    tokenizer).
    """

    def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
        """Wrap ``tokenizer`` and pre-compute vocabulary/special-token caches.

        Raises:
            ValueError: If the Mistral tokenizer's request validator is not
                in ``ValidationMode.test``.
            TypeError: If the underlying tokenizer is neither a Tekkenizer
                nor a SentencePieceTokenizer.
        """
        from mistral_common.protocol.instruct.validator import ValidationMode
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        self.transformers_tokenizer = tokenizer
        self.mistral = tokenizer.tokenizer
        self.instruct = self.mistral.instruct_tokenizer
        self.tokenizer = self.instruct.tokenizer

        validation_mode = self.mistral._chat_completion_request_validator._mode
        if validation_mode != ValidationMode.test:
            raise ValueError(
                "Mistral tokenizer must be in test mode. Make sure to "
                "set `mode='ValidationMode.test'` when creating the "
                "Mistral tokenizer."
            )

        # The numeric version is whatever follows the last "v" in the
        # tokenizer's version value.
        version_text = str(self.tokenizer.version.value)
        self.version: int = int(version_text.split("v")[-1])

        self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
        self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
        if not (self.is_tekken or self.is_spm):
            raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")

        # Walk the ids from highest to lowest so that, when several ids map
        # to the same piece, the lowest id is the one that is kept.
        piece_to_id: dict[str, int] = {}
        for token_id in reversed(range(self.vocab_size)):
            piece = self.convert_ids_to_tokens(
                [token_id], skip_special_tokens=False
            )[0]
            piece_to_id[piece] = token_id
        # Store the mapping ordered by token id for convenience.
        self._vocab_dict = dict(
            sorted(piece_to_id.items(), key=lambda item: item[1])
        )

        # Cache special tokens for faster access.
        self._special_token_ids = self._get_special_token_ids()
        self._special_token_ids_set = set(self._special_token_ids)
        self._special_tokens = self._get_special_tokens(self._special_token_ids)
        self._special_tokens_set = set(self._special_tokens)

        # Vocab sorted by token id.
        self._vocab = self.tokenizer._vocab
        self._max_token_id = self.vocab_size - 1

    @classmethod
    def from_pretrained(
        cls, path_or_repo_id: str, *, revision: str | None = None
    ) -> "MistralTokenizer":
        """Load the HF Mistral tokenizer in test validation mode and wrap it.

        ``revision`` defaults to ``"main"`` when it is ``None``.
        """
        from mistral_common.protocol.instruct.validator import ValidationMode
        from transformers.tokenization_mistral_common import (
            MistralCommonTokenizer as TransformersMistralTokenizer,
        )

        if revision is None:
            str_revision = "main"
        else:
            str_revision = revision

        hf_tokenizer = TransformersMistralTokenizer.from_pretrained(
            path_or_repo_id, revision=str_revision, mode=ValidationMode.test
        )
        return cls(hf_tokenizer)

    def _get_special_token_ids(self) -> list[int]:
        """Return the sorted ids of the underlying tokenizer's special tokens."""
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            return sorted(
                {entry["rank"] for entry in self.tokenizer._all_special_tokens}
            )

        if self.is_spm:
            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
                self.tokenizer
            )
            return sorted(self.tokenizer._control_tokens)

        raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")

    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
        """Decode each id in ``all_special_ids`` on its own, keeping specials."""
        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

        keep = SpecialTokenPolicy.KEEP
        return [
            self.tokenizer.decode([special_id], special_token_policy=keep)
            for special_id in all_special_ids
        ]

    # the following attributes are set to fit vLLM's design and are used
    # by the structured output backends.
    @property
    def all_special_tokens(self) -> list[str]:
        """Cached special tokens, in the same order as ``all_special_ids``."""
        return self._special_tokens

    @property
    def all_special_ids(self) -> list[int]:
        """Cached, sorted special token ids."""
        return self._special_token_ids

    @property
    def bos_token_id(self) -> int:
        """``bos_id`` of the underlying tokenizer."""
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        """``eos_id`` of the underlying tokenizer."""
        return self.tokenizer.eos_id

    @property
    def is_fast(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the wrapped transformers tokenizer."""
        return self.transformers_tokenizer.vocab_size

    @property
    def max_token_id(self) -> int:
        """``vocab_size - 1`` as computed at construction time."""
        return self._max_token_id

    @property
    def truncation_side(self) -> str:
        """Truncation side of the wrapped transformers tokenizer."""
        return self.transformers_tokenizer.truncation_side

    def _is_special_token_id(self, token_id: int) -> bool:
        """Tell whether ``token_id`` is one of the cached special token ids."""
        return token_id in self._special_token_ids_set

    def __hash__(self) -> int:
        """Hash by object identity."""
        return hash(id(self))

    def __len__(self) -> int:
        """Return ``vocab_size``."""
        return self.vocab_size

    def __call__(
        self,
        text: str | list[str] | list[int],
        text_pair: str | None = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: int | None = None,
    ):
        """Tokenize ``text`` through the wrapped transformers tokenizer.

        A trailing EOS id on ``input_ids`` is removed, together with the
        last ``attention_mask`` entry when a non-empty mask is present.

        Raises:
            ValueError: If ``text_pair`` is given.
        """
        if text_pair is not None:
            raise ValueError(
                "`text_pair` is not supported by `MistralTokenizer.__call__`."
            )

        encoded = self.transformers_tokenizer(
            text=text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, revert to only call self.transformers_tokenizer(...).
        # Hack to fix wrongly added eos token, when fix will be supported the condition
        # below will be False even before the revert is done.
        input_ids = encoded["input_ids"]
        if input_ids and input_ids[-1] == self.eos_token_id:
            input_ids.pop(-1)
            mask = encoded.get("attention_mask")
            if mask:
                mask.pop(-1)
        return encoded

    @property
    def vocab(self) -> list[str]:
        """The underlying tokenizer's ``_vocab``."""
        return self._vocab

    def get_vocab(self) -> dict[str, int]:
        """Return the token -> id mapping built at construction time."""
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
        """Return an empty dict: Mistral tokenizers have no added vocabulary."""
        return {}

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool | None = None,
    ) -> list[int]:
        """Encode ``text`` with the underlying tokenizer, never adding EOS.

        BOS is added unless ``add_special_tokens`` is exactly ``False``. The
        result is cut to ``max_length`` ids when ``max_length`` is given and
        ``truncation`` is not exactly ``False``.
        """
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.encode(...).
        with_bos = add_special_tokens is not False
        token_ids = self.tokenizer.encode(text, bos=with_bos, eos=False)

        should_truncate = truncation is not False and max_length is not None
        return token_ids[:max_length] if should_truncate else token_ids

    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: list[dict[str, Any]] | None = None,
        **kwargs,
    ) -> list[int]:
        """Validate/normalize the conversation, then tokenize it.

        Recognized ``kwargs``: ``add_generation_prompt`` and
        ``continue_final_message`` (validation only for the former),
        ``padding``, ``truncation`` and ``max_length``. Other keyword
        arguments are ignored.
        """
        add_generation_prompt = kwargs.get("add_generation_prompt", False)
        continue_final_message = kwargs.get("continue_final_message", False)

        messages, tools = _prepare_apply_chat_template_tools_and_messages(
            messages, tools, continue_final_message, add_generation_prompt
        )

        template_options = {
            "continue_final_message": continue_final_message,
            "tokenize": True,
            "padding": kwargs.get("padding", False),
            "truncation": kwargs.get("truncation", False),
            "max_length": kwargs.get("max_length"),
            "return_tensors": None,
            "return_dict": False,
        }
        return self.transformers_tokenizer.apply_chat_template(
            conversation=messages,
            tools=tools,
            **template_options,
        )

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
        """Decode one id or a list of ids via the wrapped transformers tokenizer."""
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.decode(...).
        id_list = [ids] if isinstance(ids, int) else ids

        return self.transformers_tokenizer.decode(
            id_list, skip_special_tokens=skip_special_tokens
        )

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join ``tokens`` back into text, keeping only the tool-calls special.

        For Tekken, every cached special token except the tool-calls token
        is dropped first; for SentencePiece, runs of tokens between
        tool-calls tokens are decoded with special tokens ignored.
        """
        from mistral_common.tokens.tokenizers.base import (
            SpecialTokenPolicy,
            SpecialTokens,
        )
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        to_decode_special_tokens = {SpecialTokens.tool_calls}

        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            kept = [
                tok
                for tok in tokens
                if (
                    tok in to_decode_special_tokens
                    or tok not in self._special_tokens_set
                )
            ]

            if not any(isinstance(tok, bytes) for tok in kept):
                return "".join(kept)

            # At least one token is raw bytes: map every token back to its
            # id and decode them together. Unwanted special tokens were
            # already filtered out, so the rest can be kept.
            token_ids = [_tekken_token_to_id(self.tokenizer, tok) for tok in kept]
            return self.tokenizer.decode(token_ids, SpecialTokenPolicy.KEEP)

        # make sure certain special tokens like Tool calls are
        # not decoded
        assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
            self.tokenizer
        )

        pieces: list[str] = []
        pending: list[str] = []
        for token in tokens:
            if token not in to_decode_special_tokens:
                pending.append(token)
                continue

            # Flush the regular tokens collected so far, then emit the
            # special token verbatim.
            if pending:
                pieces.append(
                    self.tokenizer.decode(pending, SpecialTokenPolicy.IGNORE)
                )
                pending = []
            pieces.append(token)

        if pending:
            pieces.append(self.tokenizer.decode(pending, SpecialTokenPolicy.IGNORE))

        return "".join(pieces)

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = True,
    ) -> list[str]:
        """Convert token ids to token pieces.

        With ``skip_special_tokens=False`` every id is converted as is.
        Otherwise special ids are dropped, except the tool-calls control
        token and, for ``InstructTokenizerV13``, the begin/end think tokens
        when they are set.
        """
        from mistral_common.tokens.tokenizers.base import (
            SpecialTokenPolicy,
            SpecialTokens,
        )
        from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13

        if not skip_special_tokens:
            return [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        always_kept_ids = {
            self.tokenizer.get_control_token(SpecialTokens.tool_calls),
        }
        if isinstance(self.instruct, InstructTokenizerV13):
            for think_id in (self.instruct.BEGIN_THINK, self.instruct.END_THINK):
                if think_id:
                    always_kept_ids.add(think_id)

        ids_kept = [
            token_id
            for token_id in ids
            if token_id in always_kept_ids or not self._is_special_token_id(token_id)
        ]

        # We filtered unwanted special tokens so we can decode the rest.
        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]

        if any("�" in piece for piece in tokens) and self.is_tekken:
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # if underlying tokenizer is sentencepiece, we just add "�".
            # We filtered unwanted special tokens so we can decode the rest.
            keep = SpecialTokenPolicy.KEEP
            tokens = [
                self.tokenizer.decode([token_id], keep)
                if token_id in self._special_token_ids_set
                else self.tokenizer.id_to_byte_piece(token_id, keep)
                for token_id in ids_kept
            ]

        return tokens