"tests/models/test_llava_next_video.py" did not exist on "38ef94888afc0c2bccc2f18422d2b525d7649ac3"
mistral.py 19.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, cast
5

6
from vllm.logger import init_logger
7
from vllm.transformers_utils.tokenizer_base import TokenizerBase
8

9
if TYPE_CHECKING:
10
11
12
13
14
15
    from mistral_common.protocol.instruct.request import (
        ChatCompletionRequest as MistralChatCompletionRequest,
    )
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
    from transformers.tokenization_mistral_common import (
        MistralCommonTokenizer as TransformersMistralTokenizer,
16
    )
17

18
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
19
    from vllm.entrypoints.openai.protocol import ChatCompletionRequest
20

21
22
# Module-level logger used for the truncation / unknown-token warnings below.
logger = init_logger(__name__)

23

24
def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
    """Force assistant ``tool_calls`` to be materialized as plain lists.

    Every assistant message in ``request.messages`` is updated in place so
    that its ``"tool_calls"`` entry is a list:

    * a lazy iterator (e.g. a pydantic ``ValidatorIterator``) is drained,
    * an already-materialized list/tuple is copied into a new list,
    * a missing or ``None`` value becomes ``[]``.

    Messages with any other role are left untouched.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
    # declared as iterables are replaced in the instances by
    # pydantic-core ValidatorIterator instances. In particular, this
    # affects tool_calls defined in the ChatCompletionAssistantMessageParam
    # model:
    # see:
    #   - https://github.com/pydantic/pydantic/issues/9467
    # As a result, tool_calls from assistant messages are never
    # deserialized in the request object if the tool_calls iterator is
    # not consumed. This affects messages passed to the MistralTokenizer
    # since no chat template is applied and therefore the tool_calls
    # iterator is not directly consumed.
    # Issue is tracked on Pydantic side, with resolution planned for
    # v2.11 release. In the meantime, the official workaround is to
    # consume the iterator so the tool_calls are correctly deserialized
    # in the OpenAI ChatCompletionAssistantMessageParam object
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # Official Pydantic Issues:
    #   - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        # ``list()`` drains a one-shot iterator but also accepts a list/tuple
        # that was already materialized (a bare ``next()`` loop would raise
        # TypeError on those). ``None`` means "no tool calls".
        request.messages[i]["tool_calls"] = (
            [] if tool_calls is None else list(tool_calls)
        )


61
def truncate_tool_call_ids(request: "MistralChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements.

    IDs longer than 9 characters are cut down to their last 9 characters.
    Both sides of a tool exchange are rewritten the same way so they still
    match:

    * assistant messages: every ``tool_calls[*]["id"]``;
    * ``"tool"`` / ``"tool_results"`` messages: ``tool_call_id``.

    Messages are mutated in place and a warning is logged per truncation.
    Assistant messages whose ``tool_calls`` is missing or ``None`` end up
    with ``tool_calls == []``; a ``None`` ``tool_call_id`` is left as is.
    """
    max_id_len = 9

    for i, message in enumerate(request.messages):
        role = message.get("role")

        if role == "assistant":
            # OpenAI-style assistant messages may carry ``tool_calls: None``.
            tool_calls = message.get("tool_calls") or []
            for tool_call in tool_calls:
                call_id = tool_call["id"]
                if len(call_id) > max_id_len:
                    logger.warning(
                        "Truncating tool call ID: %s to %s",
                        call_id,
                        call_id[-max_id_len:],
                    )
                    tool_call["id"] = call_id[-max_id_len:]

            request.messages[i]["tool_calls"] = tool_calls

        elif role in {"tool_results", "tool"}:
            if "tool_call_id" not in message:
                continue
            tool_call_id = message["tool_call_id"]
            # Only strings have a length to enforce; leave ``None`` untouched.
            if isinstance(tool_call_id, str) and len(tool_call_id) > max_id_len:
                logger.warning(
                    "Truncating tool_call_id: %s to %s",
                    tool_call_id,
                    tool_call_id[-max_id_len:],
                )
                tool_call_id = tool_call_id[-max_id_len:]
            request.messages[i]["tool_call_id"] = tool_call_id


91
92
def _prepare_apply_chat_template_tools_and_messages(
    messages: list["ChatCompletionMessageParam"],
93
    tools: list[dict[str, Any]] | None = None,
94
95
    continue_final_message: bool = False,
    add_generation_prompt: bool = False,
96
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
97
    if add_generation_prompt and continue_final_message:
98
        raise ValueError(
99
100
            "Cannot set both `add_generation_prompt` and "
            "`continue_final_message` to True."
101
        )
102

103
104
105
106
107
108
109
110
111
112
    last_message = cast(dict[str, Any], messages[-1])
    # add_generation_prompt is directly handled by the tokenizer but we
    # check if the user is trying to use it with a final assistant message
    # which is probably not what they want.
    # If add_generation_prompt is False, we don't need to check anything.
    if add_generation_prompt and last_message["role"] == "assistant":
        raise ValueError(
            "Cannot set `add_generation_prompt` to True when "
            "the last message is from the assistant. Consider "
            "using `continue_final_message` instead."
113
        )
114
115
116
117
    if continue_final_message and last_message["role"] != "assistant":
        raise ValueError(
            "Cannot set `continue_final_message` to True when "
            "the last message is not from the assistant."
118
        )
119

120
121
122
123
    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
124
125
        # Remove reasoning as unsupported by Mistral
        _ = message.pop("reasoning", None)  # type: ignore
126

127
    # The Mistral client, in comparison to the OpenAI client, requires the
128
129
    # "parameters" dict and the "description" string to be present
    # even if they are empty.
130
131
    if tools:
        for function in [
132
            tool["function"] for tool in tools if tool["type"] == "function"
133
        ]:
134
135
            if function.get("parameters") is None:
                function["parameters"] = {}
136
137
            if function.get("description") is None:
                function["description"] = ""
138

139
    return messages, tools
140
141


142
143
144
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject request fields that only make sense with a chat template.

    Mistral tokenizers render the conversation themselves, so a request that
    sets either ``chat_template`` or ``chat_template_kwargs`` is refused.

    Raises:
        ValueError: If ``chat_template`` or ``chat_template_kwargs`` is set.
    """
    has_custom_template = request.chat_template is not None
    if has_custom_template or request.chat_template_kwargs is not None:
        raise ValueError("chat_template is not supported for Mistral tokenizers.")
145
146


147
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
    """Map a Tekken token (text or raw bytes) back to its token id.

    Lookup order:

    1. the non-special byte vocabulary, offset by ``num_special_tokens``;
    2. the special-token vocabulary, keyed by the token's text;
    3. otherwise a warning is logged and ``tokenizer.unk_id`` is returned.
    """
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer

    assert isinstance(tokenizer, Tekkenizer), type(tokenizer)

    t_bytes = t if isinstance(t, bytes) else t.encode("utf-8")
    shift = tokenizer.num_special_tokens
    try:
        return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
    except KeyError:
        pass

    # Byte tokens may be partial UTF-8 sequences. Decode leniently so an
    # unknown fragment falls through to <unk> (with a warning) instead of
    # escaping as a UnicodeDecodeError.
    t_str = t_bytes.decode("utf-8", errors="replace")
    if t_str in tokenizer._special_tokens_reverse_vocab:
        return tokenizer._special_tokens_reverse_vocab[t_str]
    logger.warning("Failed to convert token %s to id, replacing with <unk>", t_bytes)
    return tokenizer.unk_id

165

166
167
class MistralTokenizer(TokenizerBase):
    """vLLM ``TokenizerBase`` adapter around a Mistral tokenizer.

    Wraps a ``transformers`` ``MistralCommonTokenizer`` and the
    ``mistral_common`` tokenizer stack underneath it. Only Tekken and
    SentencePiece backends are supported, and the wrapped tokenizer must be
    in ``ValidationMode.test``.
    """

    def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
        """Wrap ``tokenizer`` and cache vocab / special-token lookups.

        Raises:
            ValueError: If the request validator is not in test mode.
            TypeError: If the backend is neither Tekken nor SentencePiece.
        """
        from mistral_common.protocol.instruct.validator import ValidationMode
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        # transformers wrapper -> mistral_common tokenizer -> instruct
        # tokenizer -> raw (Tekken / SentencePiece) tokenizer.
        self.transformers_tokenizer = tokenizer
        self.mistral = tokenizer.tokenizer
        self.instruct = self.mistral.instruct_tokenizer
        self.tokenizer = self.instruct.tokenizer

        mode = self.mistral._chat_completion_request_validator._mode
        if mode != ValidationMode.test:
            raise ValueError(
                "Mistral tokenizer must be in test mode. Make sure to "
                "set `mode='ValidationMode.test'` when creating the "
                "Mistral tokenizer."
            )

        # The integer after the last "v" of the version value, e.g. "v7" -> 7
        # (assumes that naming scheme for version values).
        _mistral_version_str = str(self.tokenizer.version.value)
        self.version: int = int(_mistral_version_str.split("v")[-1])

        self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
        self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
        if not (self.is_tekken or self.is_spm):
            raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")

        # Token string -> id. Several ids can share a piece; walking ids in
        # ascending order with ``setdefault`` keeps the lowest id, and the
        # insertion order leaves the dict sorted by id.
        all_pieces = self.convert_ids_to_tokens(
            list(range(self.vocab_size)), skip_special_tokens=False
        )
        self._vocab_dict: dict[str, int] = {}
        for token_id, piece in enumerate(all_pieces):
            self._vocab_dict.setdefault(piece, token_id)

        # Cache special tokens for faster access.
        self._special_token_ids = self._get_special_token_ids()
        self._special_token_ids_set = set(self._special_token_ids)
        self._special_tokens = self._get_special_tokens(self._special_token_ids)
        self._special_tokens_set = set(self._special_tokens)

        # Vocab sorted by token id.
        self._vocab = self.tokenizer._vocab
        self._max_token_id = self.vocab_size - 1

    @classmethod
    def from_pretrained(
        cls, path_or_repo_id: str, *, revision: str | None = None
    ) -> "MistralTokenizer":
        """Load a Mistral tokenizer (in test validation mode) and wrap it.

        ``revision`` defaults to ``"main"`` when not given.
        """
        from mistral_common.protocol.instruct.validator import ValidationMode
        from transformers.tokenization_mistral_common import (
            MistralCommonTokenizer as TransformersMistralTokenizer,
        )

        str_revision = "main" if revision is None else revision
        return cls(
            TransformersMistralTokenizer.from_pretrained(
                path_or_repo_id, revision=str_revision, mode=ValidationMode.test
            )
        )

    def _get_special_token_ids(self) -> list[int]:
        """Return the sorted ids of the backend's special/control tokens."""
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
        elif self.is_spm:
            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
                self.tokenizer
            )
            special_ids = self.tokenizer._control_tokens
        else:
            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
        return sorted(special_ids)

    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
        """Return the text of each id in ``all_special_ids``, in order."""
        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

        return [
            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
            for i in all_special_ids
        ]

    # the following attributes are set to fit vLLM's design and are used
    # by the structured output backends.
    @property
    def all_special_tokens(self) -> list[str]:
        return self._special_tokens

    @property
    def all_special_ids(self) -> list[int]:
        return self._special_token_ids

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def sep_token(self) -> str:
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        return self.transformers_tokenizer.pad_token

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def vocab_size(self) -> int:
        return self.transformers_tokenizer.vocab_size

    @property
    def max_token_id(self) -> int:
        return self._max_token_id

    @property
    def truncation_side(self) -> str:
        raise NotImplementedError()

    def _is_special_token_id(self, token_id: int) -> bool:
        return token_id in self._special_token_ids_set

    def __len__(self) -> int:
        return self.vocab_size

    def __call__(
        self,
        text: str | list[str] | list[int],
        text_pair: str | None = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: int | None = None,
    ):
        """Tokenize ``text`` through the transformers wrapper.

        Returns the wrapper's encoding with a trailing EOS id (and its
        ``attention_mask`` entry, if present) removed from every sequence.
        That covers both single inputs and batched ``list[str]`` inputs.

        Raises:
            ValueError: If ``text_pair`` is given.
        """
        if text_pair is not None:
            raise ValueError(
                "`text_pair` is not supported by `MistralTokenizer.__call__`."
            )

        encoded = self.transformers_tokenizer(
            text=text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )

        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, revert to only call self.transformers_tokenizer(...).
        # Hack to fix wrongly added eos token, when fix will be supported the condition
        # below will be False even before the revert is done.
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")
        if input_ids and isinstance(input_ids[0], list):
            # Batched input: one id list (and one mask list) per text. A flat
            # ``input_ids[-1] == eos`` check would compare a list to an int.
            for i, seq_ids in enumerate(input_ids):
                seq_mask = attention_mask[i] if attention_mask else None
                self._pop_trailing_eos(seq_ids, seq_mask)
        else:
            self._pop_trailing_eos(input_ids, attention_mask)
        return encoded

    def _pop_trailing_eos(self, ids: list[int], mask: list[int] | None) -> None:
        """Remove a trailing EOS id from ``ids`` (and ``mask``) in place."""
        if ids and ids[-1] == self.eos_token_id:
            ids.pop(-1)
            if mask:
                mask.pop(-1)

    @property
    def vocab(self) -> list[str]:
        return self._vocab

    def get_vocab(self) -> dict[str, int]:
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> list[int]:
        """Encode ``text`` without any special tokens."""
        # Mistral Tokenizers should not add special tokens
        return self.transformers_tokenizer.encode(
            text, add_special_tokens=False, truncation=truncation, max_length=max_length
        )

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool | None = None,
    ) -> list[int]:
        """Encode ``text`` with the raw tokenizer; never appends EOS.

        BOS is added unless ``add_special_tokens`` is explicitly ``False``.
        The result is cut to ``max_length`` when ``max_length`` is set and
        ``truncation`` is not explicitly ``False``.
        """
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.encode(...).
        encoded = self.tokenizer.encode(
            text, bos=add_special_tokens is not False, eos=False
        )

        if truncation is not False and max_length is not None:
            return encoded[:max_length]
        else:
            return encoded

    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: list[dict[str, Any]] | None = None,
        **kwargs,
    ) -> list[int]:
        """Render and tokenize a conversation with the Mistral chat format.

        Recognized ``kwargs``: ``add_generation_prompt`` (validated only, not
        forwarded), ``continue_final_message``, ``padding``, ``truncation``
        and ``max_length``. Other keys are ignored.
        """
        add_generation_prompt = kwargs.pop("add_generation_prompt", False)
        continue_final_message = kwargs.get("continue_final_message", False)
        padding = kwargs.get("padding", False)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        messages, tools = _prepare_apply_chat_template_tools_and_messages(
            messages, tools, continue_final_message, add_generation_prompt
        )

        return self.transformers_tokenizer.apply_chat_template(
            conversation=messages,
            tools=tools,
            continue_final_message=continue_final_message,
            tokenize=True,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=None,
            return_dict=False,
        )

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
        """Decode one id or a list of ids into text."""
        # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
        # is in, directly call self.transformers_tokenizer.decode(...).
        if isinstance(ids, int):
            ids = [ids]

        return self.transformers_tokenizer.decode(
            ids, skip_special_tokens=skip_special_tokens
        )

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join token pieces back into text.

        ``SpecialTokens.tool_calls`` is always kept verbatim. On Tekken, other
        special tokens are dropped; if any piece is raw ``bytes``, the pieces
        are mapped back to ids and decoded together. On SentencePiece, runs
        of other pieces are decoded with ``SpecialTokenPolicy.IGNORE``.
        """
        from mistral_common.tokens.tokenizers.base import (
            SpecialTokenPolicy,
            SpecialTokens,
        )
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        to_decode_special_tokens = {SpecialTokens.tool_calls}
        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            tokens = [
                t
                for t in tokens
                if (t in to_decode_special_tokens or t not in self._special_tokens_set)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens]
                # We filtered unwanted special tokens before
                # so we can decode the rest.
                decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
            else:
                decoded = "".join(tokens)
        else:
            # make sure certain special tokens like Tool calls are
            # not decoded
            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
                self.tokenizer
            )

            regular_tokens: list[str] = []
            decoded_list: list[str] = []

            for token in tokens:
                if token in to_decode_special_tokens:
                    # Flush the pending run of regular pieces before the
                    # verbatim special token.
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(
                                regular_tokens, SpecialTokenPolicy.IGNORE
                            )
                        )
                        regular_tokens = []
                    decoded_list.append(token)
                else:
                    regular_tokens.append(token)

            if regular_tokens:
                decoded_list.append(
                    self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
                )
            decoded = "".join(decoded_list)

        return decoded

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = True,
    ) -> list[str]:
        """Convert ids to token pieces.

        With ``skip_special_tokens=False`` every id is converted as is.
        Otherwise special ids are dropped, except the tool-calls token and,
        for ``InstructTokenizerV13``, the think begin/end tokens. On Tekken,
        if any piece contains U+FFFD, pieces are returned as raw bytes
        (special tokens as decoded text) so partial UTF-8 survives.
        """
        from mistral_common.tokens.tokenizers.base import (
            SpecialTokenPolicy,
            SpecialTokens,
        )
        from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13

        if not skip_special_tokens:
            return [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        non_skip_special_tokens_ids = {
            self.tokenizer.get_control_token(SpecialTokens.tool_calls),
        }
        if isinstance(self.instruct, InstructTokenizerV13):
            if self.instruct.BEGIN_THINK:
                non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
            if self.instruct.END_THINK:
                non_skip_special_tokens_ids.add(self.instruct.END_THINK)

        ids_kept = [
            i
            for i in ids
            if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
        ]

        # We filtered unwanted special tokens so we can decode the rest.
        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]

        if any("�" in t for t in tokens) and self.is_tekken:
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # if underlying tokenizer is sentencepiece, we just add "�".
            # We filtered unwanted special tokens so we can decode the rest.
            tokens = [
                self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
                if token_id not in self._special_token_ids_set
                else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
                for token_id in ids_kept
            ]

        return tokens