mistral.py 19.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import os
from dataclasses import dataclass
from pathlib import Path
7
from typing import TYPE_CHECKING, Any, Optional, Union, cast
8

9
import huggingface_hub
10
import regex as re
11
12
from huggingface_hub import HfApi, hf_hub_download

13
from vllm.logger import init_logger
14
from vllm.transformers_utils.tokenizer_base import TokenizerBase
15
from vllm.utils import is_list_of
16

17
if TYPE_CHECKING:
18
19
20
21
22
23
24
    # make sure `mistral_common` is lazy imported,
    # so that users who only use non-mistral models
    # will not be bothered by the dependency.
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import (
        MistralTokenizer as PublicMistralTokenizer)

25
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
26

27
28
# Module-level logger obtained from vLLM's ``init_logger``; used by
# ``truncate_tool_call_ids`` and ``MistralTokenizer`` to emit warnings.
logger = init_logger(__name__)

29
30
31

@dataclass
class Encoding:
    """Result object returned by ``MistralTokenizer.__call__``.

    Only ``input_ids`` is carried: a flat list of token ids for a single
    prompt, or one list of token ids per prompt for a batch.
    """

    # Token ids: list[int] for one prompt, list[list[int]] for a batch.
    input_ids: Union[list[int], list[list[int]]]
33
34


35
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
    """Materialize lazily-validated ``tool_calls`` on assistant messages.

    After this call every assistant message in ``request.messages`` holds a
    plain ``list`` under ``"tool_calls"`` (empty when the message had no
    tool calls or had ``None``). Messages are updated in place; messages of
    other roles are left untouched.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
    # declared as iterables are replaced in the instances by a
    # pydantic-core ValidatorIterator instance. In particular, this
    # affects tool_calls defined in the ChatCompletionAssistantMessageParam
    # model:
    # see:
    #   - https://github.com/pydantic/pydantic/issues/9467
    # As a result, tool_calls from assistant messages are never
    # deserialized in the request object if the tool_calls iterator is
    # not consumed. This affects messages passed to the MistralTokenizer,
    # since no chat template is applied and therefore the tool_calls
    # iterator is not directly consumed.
    # Issue is tracked on Pydantic side, with resolution planned for
    # v2.11 release. In the meantime, the official workaround is to
    # consume the iterator so the tool_calls are correctly deserialized
    # in the OpenAI ChatCompletionAssistantMessageParam object
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # Official Pydantic Issues:
    #   - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue
        tool_calls = message.get("tool_calls")
        # ``list()`` drains a ValidatorIterator and equally accepts a list or
        # tuple that is already materialized (``next()`` would reject those
        # with a TypeError). An explicit ``None`` means "no tool calls".
        request.messages[i]["tool_calls"] = ([] if tool_calls is None else
                                             list(tool_calls))


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
def truncate_tool_call_ids(request: "ChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements.

    IDs longer than 9 characters are cut down to their last 9 characters,
    both on the ``tool_calls`` entries of assistant messages and on the
    ``tool_call_id`` of ``tool`` / ``tool_results`` messages. A warning is
    logged for every truncation. Messages are mutated in place.
    """
    max_id_len = 9
    for i, message in enumerate(request.messages):
        role = message.get("role")
        if role == "assistant":
            # Materialize first: ``tool_calls`` may still be a one-shot
            # iterator (see ``maybe_serialize_tool_calls``). Iterating it
            # here and then writing the same exhausted iterator back would
            # silently drop every tool call. An explicit ``None`` is treated
            # as "no tool calls".
            tool_calls = list(message.get("tool_calls") or [])
            for tool_call in tool_calls:
                if len(tool_call["id"]) > max_id_len:
                    logger.warning(
                        "Truncating tool call ID: %s to %s",
                        tool_call["id"],
                        tool_call["id"][-max_id_len:],
                    )
                    tool_call["id"] = tool_call["id"][-max_id_len:]

            request.messages[i]["tool_calls"] = tool_calls

        elif role in {"tool_results", "tool"}:
            tool_call_id = message.get("tool_call_id")
            # Absent or None: nothing to truncate.
            if tool_call_id is None:
                continue

            if len(tool_call_id) > max_id_len:
                logger.warning(
                    "Truncating tool_call_id: %s to %s",
                    tool_call_id,
                    tool_call_id[-max_id_len:],
                )
                tool_call_id = tool_call_id[-max_id_len:]
            request.messages[i]["tool_call_id"] = tool_call_id


102
103
104
105
106
107
108
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject request options that Mistral tokenizers cannot honour.

    ``skip_special_tokens`` left unset (``None``) or truthy is accepted.

    Raises:
        ValueError: if ``skip_special_tokens`` is set to a falsy value.
    """
    skip = request.skip_special_tokens
    if skip is None or skip:
        return
    raise ValueError("skip_special_tokens=False is not supported "
                     "for Mistral tokenizers.")


109
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
    """List the entries of a locally cached HF Hub model snapshot.

    Only the local Hugging Face cache (``HF_HUB_CACHE``) is read; no network
    access is made.

    Args:
        repo_id: Hub repo id, e.g. ``"org/name"``.
        revision: Branch/tag name or commit hash; ``None`` means ``"main"``.
            A name with a matching file under the cache's ``refs/``
            directory is resolved to the commit hash stored in that file;
            any other value is used directly as the snapshot directory name.

    Returns:
        Names of the top-level entries of the snapshot directory, or an
        empty list if the revision cannot be resolved or is not cached.
    """
    repo_cache = os.path.join(
        huggingface_hub.constants.HF_HUB_CACHE,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *repo_id.split("/")]))

    # Resolve a branch/tag name (default "main") to the commit hash recorded
    # in refs/<name>. Without this, an explicit branch name such as "main"
    # would be looked up as a snapshot directory and never found.
    ref_name = "main" if revision is None else revision
    revision_file = os.path.join(repo_cache, "refs", ref_name)
    if ref_name and os.path.isfile(revision_file):
        with open(revision_file, encoding="utf-8") as file:
            # Strip so a trailing newline cannot corrupt the snapshot path.
            revision = file.read().strip()

    if revision:
        revision_dir = os.path.join(repo_cache, "snapshots", revision)
        if os.path.isdir(revision_dir):
            return os.listdir(revision_dir)

    return []


129
def find_tokenizer_file(files: list[str]):
    """Pick the single Mistral tokenizer file out of a list of file names.

    Recognised names are ``tokenizer.model.v*``, ``tokenizer.mm.model.v*``
    and ``tekken.json``.

    Raises:
        OSError: if no entry, or more than one entry, of ``files`` matches.
    """
    tokenizer_re = re.compile(
        r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
    candidates = list(filter(tokenizer_re.match, files))
    count = len(candidates)

    if count == 1:
        return candidates[0]

    if count == 0:
        advice = "Make sure that a Mistral "
    else:
        advice = "Make sure only one Mistral "
    raise OSError(f"Found {count} files matching the "
                  f"pattern: `{tokenizer_re.pattern}`. {advice}"
                  f"tokenizer is present in {files}.")


Julien Denize's avatar
Julien Denize committed
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def _aggregate_content(content: list) -> list[dict[str, Any]]:
    aggregated_content: list[dict[str, Any]] = []
    for chunk in content:
        if chunk.get("type"
                     ) == "text" and aggregated_content and aggregated_content[
                         -1].get("type") == "text":
            aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
        else:
            aggregated_content.append(chunk)
    if len(aggregated_content) == 1 and aggregated_content[0].get(
            "type") == "text":
        content = aggregated_content[0]["text"]
    return content


163
def make_mistral_chat_completion_request(
        messages: list["ChatCompletionMessageParam"],
        tools: Optional[list[dict[str,
                                  Any]]] = None) -> "ChatCompletionRequest":
    """Build a mistral-common ``ChatCompletionRequest`` from OpenAI-style input.

    ``messages`` and ``tools`` are normalized in place first:

    * a trailing assistant message gets ``"prefix": True``;
    * ``reasoning_content`` is removed from every message;
    * list-valued ``content`` of assistant and tool messages is merged with
      ``_aggregate_content``;
    * every function tool whose ``"parameters"`` is missing or ``None`` gets
      an empty dict.
    """
    final_message = cast(dict[str, Any], messages[-1])
    if final_message["role"] == "assistant":
        final_message["prefix"] = True

    for msg in messages:
        # Remove reasoning_content as unsupported by Mistral
        msg.pop("reasoning_content", None)  # type: ignore
        if msg.get("role") not in ("assistant", "tool"):
            continue

        # mistral-common requires AssistantMessage content to be string [1].
        #
        # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
        msg_content: Any = msg.get("content")
        if isinstance(msg_content, list):
            msg_content = _aggregate_content(msg_content)
        msg["content"] = msg_content

    # The Mistral client, in comparison to the OpenAI client, requires the
    # "parameters" dict to be present, even if it's empty. All function
    # specs are collected first, then patched.
    function_specs = [
        spec["function"] for spec in (tools or [])
        if spec["type"] == "function"
    ]
    for fn in function_specs:
        if fn.get("parameters") is None:
            fn["parameters"] = {}

    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    return ChatCompletionRequest(messages=messages,
                                 tools=tools)  # type: ignore[type-var]


200
class MistralTokenizer(TokenizerBase):
    """Adapter exposing a mistral-common tokenizer via ``TokenizerBase``.

    Only Tekken (``Tekkenizer``) and SentencePiece backends are supported;
    any other backend makes ``__init__`` raise ``TypeError``.
    """

    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
        self.mistral = tokenizer
        self.instruct = tokenizer.instruct_tokenizer

        # Keep the integer after the last "v" of the version value
        # (e.g. a value of "v7" gives 7).
        _mistral_version_str = self.instruct.tokenizer.version.value
        self.version: int = int(_mistral_version_str.split("v")[-1])

        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer

        from mistral_common.tokens.tokenizers.tekken import (
            SpecialTokenPolicy, Tekkenizer)
        self.is_tekken = isinstance(tokenizer_, Tekkenizer)

        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer)
        self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)

        if self.is_tekken:
            # Make sure special tokens will not raise
            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
        elif not self.is_spm:
            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

        self._vocab = tokenizer_.vocab()
        # Convert to a dict[str, int] to match protocol, but this is a lossy
        # conversion. There may be multiple token ids that decode to the same
        # string due to partial UTF-8 byte sequences being converted to �;
        # for such duplicates the highest id wins.
        self._vocab_dict = {
            token: idx
            for idx, token in enumerate(self._vocab)
        }
        self.tokenizer = tokenizer_
        self._max_token_id = self.vocab_size - 1

    @classmethod
    def from_pretrained(cls,
                        path_or_repo_id: str,
                        *,
                        revision: Optional[str] = None) -> "MistralTokenizer":
        """Load a tokenizer from a file, a directory, or an HF Hub repo.

        Args:
            path_or_repo_id: Path to a tokenizer file, a directory holding
                exactly one Mistral tokenizer file, or (when no such local
                path exists) an ``org/name`` HF Hub repo id.
            revision: Hub revision; only used for repo ids.
        """
        if not Path(path_or_repo_id).exists():
            assert len(path_or_repo_id.split("/")) == 2, (
                "You have either provided a non-existent path: "
                f"{path_or_repo_id} or an invalid HF Hub repo id.")
            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
                path_or_repo_id, revision)
        elif Path(path_or_repo_id).is_dir():
            tokenizer_file_name = find_tokenizer_file(
                os.listdir(path_or_repo_id))
            tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
        else:
            assert Path(
                path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
            tokenizer_file = str(Path(path_or_repo_id))

        from mistral_common.tokens.tokenizers.mistral import (
            MistralTokenizer as PublicMistralTokenizer)
        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
        return cls(mistral_tokenizer)

    @staticmethod
    def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                            revision: Optional[str]) -> str:
        """Download the repo's Mistral tokenizer file; return its local path.

        The repo file listing comes from the Hub. If that request fails with
        an ``OSError``, the listing is read from the local HF cache instead,
        and the original error is re-raised when nothing is cached.
        """
        try:
            hf_api = HfApi()
            files = hf_api.list_repo_files(repo_id=tokenizer_name,
                                           revision=revision)
        except OSError:
            # The builtin ConnectionError is an OSError subclass, but network
            # errors from the HTTP client (e.g. requests' ConnectionError,
            # presumably what huggingface_hub surfaces on an outage) derive
            # from OSError/IOError and not from the builtin ConnectionError.
            # Catching OSError lets both fall back to the local cache.
            files = list_local_repo_files(repo_id=tokenizer_name,
                                          revision=revision)

            if len(files) == 0:
                raise

        filename = find_tokenizer_file(files)

        tokenizer_file = hf_hub_download(tokenizer_name,
                                         filename=filename,
                                         revision=revision)
        return tokenizer_file

    # the following attributes are set to fit vLLM's design and are used
    # by the guided structured output backends.
    @property
    def all_special_tokens_extended(self) -> list[str]:
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        # tekken defines its own extended special tokens list
        if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
            special_tokens = self.tokenizer.SPECIAL_TOKENS
        else:
            special_tokens = list(SpecialTokens)
        return [
            s.value if isinstance(s, SpecialTokens) else s
            for s in special_tokens
        ]

    @property
    def all_special_tokens(self) -> list[str]:
        return self.all_special_tokens_extended

    @property
    def all_special_ids(self) -> list[int]:
        # NOTE(review): these are positions within ``all_special_tokens``
        # (first occurrence), not ids looked up in the vocabulary — confirm
        # they match the backend's real special-token ids.
        # Evaluate the property once instead of once per element.
        special_tokens = self.all_special_tokens
        return [special_tokens.index(t) for t in special_tokens]

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def sep_token(self) -> str:
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        raise NotImplementedError()

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    @property
    def max_token_id(self) -> int:
        return self._max_token_id

    def __len__(self) -> int:
        return self.vocab_size

    def __call__(
        self,
        text: Union[str, list[str], list[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Tokenize ``text`` and wrap the ids in an ``Encoding``.

        ``text`` may be a single prompt, a list of prompts (one id list per
        prompt), or a list of ids that is passed through unchanged.
        ``text_pair`` and ``add_special_tokens`` are accepted for interface
        compatibility but ignored.
        """
        input_ids: Union[list[int], list[list[int]]]
        # For list[str], original prompt text
        if is_list_of(text, str):
            input_ids_: list[list[int]] = []
            for p in text:
                each_input_ids = self.encode_one(p, truncation, max_length)
                input_ids_.append(each_input_ids)
            input_ids = input_ids_
        # For list[int], apply chat template output, already tokens.
        elif is_list_of(text, int):
            input_ids = text
        # For str, single prompt text
        else:
            input_ids = self.encode_one(text, truncation, max_length)
        return Encoding(input_ids=input_ids)

    def get_vocab(self) -> dict[str, int]:
        # NB: the dictionary form of the vocabulary collapses token ids that map
        # to the same string but have different bytes
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> list[int]:
        """Encode one prompt, keeping at most ``max_length`` ids if
        ``truncation`` is set."""
        # ``encode`` is called without ``add_special_tokens``, so its default
        # applies: BOS is prepended, EOS is not.
        input_ids = self.encode(text)

        if truncation:
            input_ids = input_ids[:max_length]
        return input_ids

    def encode(self,
               text: str,
               truncation: Optional[bool] = None,
               max_length: Optional[int] = None,
               add_special_tokens: Optional[bool] = None) -> list[int]:
        """Encode raw text.

        With ``add_special_tokens`` unset, BOS is added and EOS is not;
        otherwise both follow ``add_special_tokens``. ``truncation`` and
        ``max_length`` are accepted for interface compatibility but ignored.
        """
        # `encode` should only be used for prompt completion
        # it should never be used for chat_completion.
        # For chat completion use `apply_chat_template`
        if add_special_tokens is not None:
            return self.tokenizer.encode(text,
                                         bos=add_special_tokens,
                                         eos=add_special_tokens)
        else:
            return self.tokenizer.encode(text, bos=True, eos=False)

    def apply_chat_template(self,
                            messages: list["ChatCompletionMessageParam"],
                            tools: Optional[list[dict[str, Any]]] = None,
                            **kwargs) -> list[int]:
        """Encode a chat conversation with mistral-common.

        Extra keyword arguments are accepted but ignored.
        """
        request = make_mistral_chat_completion_request(messages, tools)
        encoded = self.mistral.encode_chat_completion(request)

        # Return the token ids produced by mistral-common's chat encoding.
        return encoded.tokens

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join token pieces back into text.

        Tekken: special tokens are dropped, except the tool-calls token;
        if any piece is ``bytes``, all pieces are mapped back to ids and
        decoded by the tokenizer, otherwise the pieces are concatenated.
        SentencePiece: runs of regular pieces are decoded by the tokenizer,
        while tool-calls tokens are emitted verbatim between them.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokens
        if self.is_tekken:
            # Use ``==``, not ``is``: pieces are normally plain str/bytes
            # (see convert_ids_to_tokens), so an identity check against the
            # enum member never matches and the tool-calls token would be
            # filtered out like any other special token.
            tokens = [
                t for t in tokens
                if (t == SpecialTokens.tool_calls
                    or t not in self.tokenizer._all_special_tokens)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                shift = self.tokenizer.num_special_tokens

                def _token_to_id(t: str):
                    t_bytes = t.encode("utf-8") \
                        if not isinstance(t, bytes) else t
                    try:
                        return shift + \
                            self.tokenizer._tekken_token2id_nospecial[t_bytes]
                    except KeyError:
                        logger.warning(
                            "Failed to convert token %s to id,"
                            " replacing with <unk>", t_bytes)
                        return self.tokenizer.unk_id

                ids = [_token_to_id(t) for t in tokens]
                decoded = self.tokenizer.decode(ids)
            else:
                decoded = "".join(tokens)
        else:
            # make sure certain special tokens like Tool calls are
            # not decoded
            special_tokens = {SpecialTokens.tool_calls}
            regular_tokens: list[str] = []
            decoded_list = []

            for token in tokens:
                if token in special_tokens:
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(regular_tokens))
                        regular_tokens = []
                    decoded_list.append(token)
                else:
                    regular_tokens.append(token)

            if regular_tokens:
                decoded_list.append(
                    self.tokenizer.decode(regular_tokens))  # type: ignore

            decoded = ''.join(decoded_list)

        return decoded

    # WARN: Outlines logits processors can overwrite this method.
    # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
    # for more.
    def decode(self,
               ids: Union[list[int], int],
               skip_special_tokens: bool = True) -> str:
        """Decode one id or a list of ids to text."""
        assert (
            skip_special_tokens
        ), "skip_special_tokens=False is not supported for Mistral tokenizers."

        if isinstance(ids, int):
            ids = [ids]
        return self.tokenizer.decode(ids)

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = True,
    ) -> list[str]:
        """Map token ids to token pieces.

        For Tekken, special-token ids are dropped except the tool-calls
        control token and, for ``InstructTokenizerV13``, the think begin/end
        tokens. If any Tekken piece contains "�" (an incomplete UTF-8
        sequence), the raw byte pieces are returned instead.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokens
        from mistral_common.tokens.tokenizers.instruct import (
            InstructTokenizerV13)

        # TODO(Patrick) - potentially allow special tokens to not be skipped
        assert (
            skip_special_tokens
        ), "skip_special_tokens=False is not supported for Mistral tokenizers."

        assert self.is_tekken or self.is_spm, type(self.tokenizer)

        if self.is_tekken:
            # skip special tokens except tool call and think tokens
            non_skip_special_tokens = {
                self.tokenizer.get_control_token(SpecialTokens.tool_calls)
            }
            if isinstance(self.instruct, InstructTokenizerV13):
                if self.instruct.BEGIN_THINK:
                    non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
                if self.instruct.END_THINK:
                    non_skip_special_tokens.add(self.instruct.END_THINK)
            # Special tokens occupy ids [0, num_special_tokens); regular
            # tokens start at num_special_tokens (the same offset used as
            # ``shift`` in convert_tokens_to_string), so that id is kept.
            num_special = self.tokenizer.num_special_tokens
            ids = [
                i for i in ids
                if i >= num_special or i in non_skip_special_tokens
            ]

        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        if self.is_tekken and any("�" in t for t in tokens):
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # If the underlying tokenizer is SentencePiece, pieces containing
            # "�" are returned as-is.
            tokens = [
                self.tokenizer.id_to_byte_piece(token_id) for token_id in ids
            ]

        return tokens