mistral.py 19.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import os
from pathlib import Path
6
from typing import TYPE_CHECKING, Any, Optional, Union, cast
7

8
import huggingface_hub
9
import regex as re
10
from huggingface_hub import HfApi, hf_hub_download
11
from transformers.tokenization_utils_base import BatchEncoding
12

13
from vllm.logger import init_logger
14
from vllm.transformers_utils.tokenizer_base import TokenizerBase
15
from vllm.utils import is_list_of
16

17
if TYPE_CHECKING:
18
19
20
21
22
23
24
    # make sure `mistral_common` is lazy imported,
    # so that users who only use non-mistral models
    # will not be bothered by the dependency.
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import (
        MistralTokenizer as PublicMistralTokenizer)

25
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
26

27
28
# Module-level logger; used below for tool-call-ID truncation warnings and
# for token-to-id conversion fallbacks in MistralTokenizer.
logger = init_logger(__name__)

29

30
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
    """Materialize the ``tool_calls`` of every assistant message in-place.

    Workaround for a pydantic bug: attributes declared as iterables are
    replaced in the instances by pydantic-core ``ValidatorIterator``
    objects. This affects ``tool_calls`` in the
    ChatCompletionAssistantMessageParam model, whose entries are never
    deserialized unless the iterator is consumed. Messages sent to the
    MistralTokenizer are not run through a chat template, so nothing else
    consumes the iterator. The official workaround is to consume it here.

    SEE: https://github.com/vllm-project/vllm/pull/9951
    Credits go to: @gcalmettes
    Pydantic issues:
      - https://github.com/pydantic/pydantic/issues/9467
      - https://github.com/pydantic/pydantic/issues/9541

    After this call, every assistant message has a ``tool_calls`` entry
    that is a plain ``list``. The list is empty when the message had no
    tool calls or an explicit ``None``.

    Args:
        request: Request whose ``messages`` are mutated in-place.
    """
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue
        tool_calls = message.get("tool_calls")
        # `list()` consumes a lazy validator iterator and also accepts an
        # already-materialized sequence. A bare `next()` would raise
        # TypeError on a list. An explicit `None` means "no tool calls".
        request.messages[i]["tool_calls"] = (
            [] if tool_calls is None else list(tool_calls))


67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def truncate_tool_call_ids(request: "ChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements.

    Any ID longer than the limit is replaced by its last 9 characters, and a
    warning is logged for each truncation. This covers every assistant
    ``tool_calls[*]["id"]`` and every ``tool_call_id`` of "tool" /
    "tool_results" messages. Messages are mutated in-place.

    Every assistant message ends up with a ``tool_calls`` list. The list is
    empty when the key was missing or ``None``.

    Args:
        request: Request whose ``messages`` are rewritten.
    """
    # Length limit applied to tool call IDs (per Mistral's ID requirements).
    max_id_len = 9
    for i, message in enumerate(request.messages):
        role = message.get("role")
        if role == "assistant":
            # Materialize first. If `tool_calls` were still a lazy iterator,
            # looping over it and then storing it back would store an
            # exhausted iterator and silently drop every tool call.
            raw_tool_calls = message.get("tool_calls")
            tool_calls = ([] if raw_tool_calls is None else
                          list(raw_tool_calls))
            for tool_call in tool_calls:
                if len(tool_call["id"]) > max_id_len:
                    logger.warning(
                        "Truncating tool call ID: %s to %s",
                        tool_call["id"],
                        tool_call["id"][-max_id_len:],
                    )
                    tool_call["id"] = tool_call["id"][-max_id_len:]

            request.messages[i]["tool_calls"] = tool_calls

        elif role in {"tool_results", "tool"}:
            if "tool_call_id" in message:
                tool_call_id = message["tool_call_id"]

                if len(tool_call_id) > max_id_len:
                    logger.warning(
                        "Truncating tool_call_id: %s to %s",
                        tool_call_id,
                        tool_call_id[-max_id_len:],
                    )
                    tool_call_id = tool_call_id[-max_id_len:]
                request.messages[i]["tool_call_id"] = tool_call_id


97
98
99
100
101
102
103
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject request options that Mistral tokenizers cannot honour.

    Raises:
        ValueError: if ``skip_special_tokens`` is explicitly set to a falsy
            value. An unset (``None``) value is accepted.
    """
    skip_flag = request.skip_special_tokens
    if skip_flag is None or skip_flag:
        return
    raise ValueError("skip_special_tokens=False is not supported "
                     "for Mistral tokenizers.")


104
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
    """List the file names of a locally cached Hugging Face repo snapshot.

    This is the offline fallback used when the Hub cannot be queried. The
    repo cache directory is ``HF_HUB_CACHE / ("models" + repo_id parts
    joined with REPO_ID_SEPARATOR)``. Each ``refs/<ref>`` file in it holds
    the name of a directory under ``snapshots/``.

    Args:
        repo_id: Hub repository id, e.g. ``"org/name"``.
        revision: A snapshot name (commit hash), or a ref name that has a
            file under ``refs/`` (e.g. ``"main"``). ``None`` means
            ``"main"``.

    Returns:
        The entries of the resolved snapshot directory. Returns ``[]`` when
        the revision cannot be resolved or is not cached.
    """
    repo_cache = os.path.join(
        huggingface_hub.constants.HF_HUB_CACHE,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *repo_id.split("/")]))

    # Resolve ref names, including the implicit "main", through refs/<ref>.
    # Snapshot directories are keyed by the value stored there, not by the
    # ref name. A revision without a matching ref file is used as-is.
    ref_name = "main" if revision is None else revision
    ref_file = os.path.join(repo_cache, "refs", ref_name)
    if os.path.isfile(ref_file):
        with open(ref_file, encoding="utf-8") as file:
            # Tolerate surrounding whitespace / a trailing newline.
            revision = file.read().strip()

    if revision:
        revision_dir = os.path.join(repo_cache, "snapshots", revision)
        if os.path.isdir(revision_dir):
            return os.listdir(revision_dir)

    return []


124
def find_tokenizer_file(files: list[str]):
    """Return the one Mistral tokenizer file name found in ``files``.

    Recognised names are ``tekken.json``, ``tokenizer.model.v*`` and
    ``tokenizer.mm.model.v*``.

    Raises:
        OSError: if no file matches, or if more than one file matches.
    """
    tokenizer_regex = re.compile(
        r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")

    candidates = list(filter(tokenizer_regex.match, files))
    if len(candidates) == 1:
        return candidates[0]

    if not candidates:
        raise OSError(
            f"Found {len(candidates)} files matching the "
            f"pattern: `{tokenizer_regex.pattern}`. Make sure that a Mistral "
            f"tokenizer is present in {files}.")
    raise OSError(
        f"Found {len(candidates)} files matching the "
        f"pattern: `{tokenizer_regex.pattern}`. Make sure only one Mistral "
        f"tokenizer is present in {files}.")


Julien Denize's avatar
Julien Denize committed
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def _aggregate_content(content: list) -> list[dict[str, Any]]:
    aggregated_content: list[dict[str, Any]] = []
    for chunk in content:
        if chunk.get("type"
                     ) == "text" and aggregated_content and aggregated_content[
                         -1].get("type") == "text":
            aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
        else:
            aggregated_content.append(chunk)
    if len(aggregated_content) == 1 and aggregated_content[0].get(
            "type") == "text":
        content = aggregated_content[0]["text"]
    return content


158
def make_mistral_chat_completion_request(
        messages: list["ChatCompletionMessageParam"],
        tools: Optional[list[dict[str,
                                  Any]]] = None) -> "ChatCompletionRequest":
    """Build a mistral-common ``ChatCompletionRequest`` from OpenAI messages.

    ``messages`` and ``tools`` are normalized in-place before they are handed
    to mistral-common:

    * a trailing assistant message is marked with ``prefix=True``;
    * ``reasoning_content`` is removed from every message, since Mistral
      does not support it;
    * list-valued ``content`` of assistant and tool messages is collapsed
      via ``_aggregate_content``, because mistral-common requires
      AssistantMessage content to be a string;
    * each function tool gets ``parameters = {}`` and ``description = ""``
      when either is missing or ``None``. Unlike the OpenAI client, the
      Mistral client requires both to be present even when empty.
    """
    tail = cast(dict[str, Any], messages[-1])
    if tail["role"] == "assistant":
        tail["prefix"] = True

    # String-content requirement for AssistantMessage:
    # https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for msg in messages:
        msg.pop("reasoning_content", None)  # type: ignore
        if msg.get("role") not in ("assistant", "tool"):
            continue
        body: Any = msg.get("content")
        msg["content"] = (_aggregate_content(body)
                          if isinstance(body, list) else body)

    if tools:
        # Collect all function specs up front, then fill in defaults.
        function_specs = [
            entry["function"] for entry in tools
            if entry["type"] == "function"
        ]
        for spec in function_specs:
            if spec.get("parameters") is None:
                spec["parameters"] = {}
            if spec.get("description") is None:
                spec["description"] = ""

    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    return ChatCompletionRequest(messages=messages,
                                 tools=tools)  # type: ignore[type-var]


198
class MistralTokenizer(TokenizerBase):
    """Adapter exposing a mistral-common tokenizer through vLLM's
    ``TokenizerBase`` interface.

    Supports mistral-common's Tekken and SentencePiece backends. Any other
    backend is rejected in ``__init__``.
    """

    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
        """Wrap an already-loaded mistral-common ``MistralTokenizer``.

        Raises:
            TypeError: if the underlying tokenizer is neither Tekken nor
                SentencePiece.
        """
        self.mistral = tokenizer
        self.instruct = tokenizer.instruct_tokenizer

        # The version number is taken from after the last "v" of the
        # version value (e.g. "v3" -> 3).
        _mistral_version_str = self.instruct.tokenizer.version.value
        self.version: int = int(_mistral_version_str.split("v")[-1])

        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
        from mistral_common.tokens.tokenizers.tekken import (
            SpecialTokenPolicy, Tekkenizer)
        self.is_tekken = isinstance(tokenizer_, Tekkenizer)
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer)
        self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
        if self.is_tekken:
            # Make sure special tokens will not raise
            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
        elif not self.is_spm:
            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

        self._vocab = tokenizer_.vocab()
        # Convert to a dict[str, int] to match protocol, but this is a lossy
        # conversion. There may be multiple token ids that decode to the same
        # string due to partial UTF-8 byte sequences being converted to �.
        # For such duplicates the last (highest) id wins.
        self._vocab_dict = {
            token: idx
            for idx, token in enumerate(self._vocab)
        }
        self.tokenizer = tokenizer_
        self._max_token_id = self.vocab_size - 1

    @classmethod
    def from_pretrained(cls,
                        path_or_repo_id: str,
                        *,
                        revision: Optional[str] = None) -> "MistralTokenizer":
        """Load a tokenizer from a file, a local directory or a Hub repo id.

        Args:
            path_or_repo_id: One of the following:
                * a path to a tokenizer file;
                * a directory containing exactly one Mistral tokenizer file;
                * an ``"org/name"`` Hub repository id, used when the path
                  does not exist locally.
            revision: Hub revision. Only used for repo ids.
        """
        if not Path(path_or_repo_id).exists():
            assert len(path_or_repo_id.split("/")) == 2, (
                "You have either provided a non-existent path: "
                f"{path_or_repo_id} or an invalid HF Hub repo id.")
            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
                path_or_repo_id, revision)
        elif Path(path_or_repo_id).is_dir():
            tokenizer_file_name = find_tokenizer_file(
                os.listdir(path_or_repo_id))
            tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
        else:
            assert Path(
                path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
            tokenizer_file = str(Path(path_or_repo_id))

        from mistral_common.tokens.tokenizers.mistral import (
            MistralTokenizer as PublicMistralTokenizer)
        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
        return cls(mistral_tokenizer)

    @staticmethod
    def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                            revision: Optional[str]) -> str:
        """Download the repo's tokenizer file and return its local path.

        The repo's file list is fetched from the Hub. On ``ConnectionError``
        it falls back to the locally cached snapshot listing, and the
        original error is re-raised if nothing is cached.
        """
        # NOTE(review): only the builtin ConnectionError is caught. Confirm
        # which exception types HfApi raises on network failures.
        try:
            hf_api = HfApi()
            files = hf_api.list_repo_files(repo_id=tokenizer_name,
                                           revision=revision)
        except ConnectionError as exc:
            files = list_local_repo_files(repo_id=tokenizer_name,
                                          revision=revision)

            if len(files) == 0:
                raise exc

        filename = find_tokenizer_file(files)

        tokenizer_file = hf_hub_download(tokenizer_name,
                                         filename=filename,
                                         revision=revision)
        return tokenizer_file

    # the following attributes are set to fit vLLM's design and are used
    # by the guided structured output backends.
    @property
    def all_special_tokens_extended(self) -> list[str]:
        """Special token strings of the underlying tokenizer."""
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        # tekken defines its own extended special tokens list
        if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
            special_tokens = self.tokenizer.SPECIAL_TOKENS
        else:
            special_tokens = list(SpecialTokens)
        return [
            s.value if isinstance(s, SpecialTokens) else s
            for s in special_tokens
        ]

    @property
    def all_special_tokens(self) -> list[str]:
        return self.all_special_tokens_extended

    @property
    def all_special_ids(self) -> list[int]:
        """Position of each special token's first occurrence in
        `all_special_tokens`."""
        # NOTE(review): this treats a token's list position as its id, i.e.
        # it assumes special tokens occupy ids 0..n-1 in list order.
        # The list is fetched once; the property rebuilds it on every access.
        special_tokens = self.all_special_tokens
        return [special_tokens.index(t) for t in special_tokens]

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def sep_token(self) -> str:
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        raise NotImplementedError()

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    @property
    def max_token_id(self) -> int:
        return self._max_token_id

    def __len__(self) -> int:
        return self.vocab_size

    def __call__(
        self,
        text: Union[str, list[str], list[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Tokenize like an HF tokenizer call and return a ``BatchEncoding``.

        * ``list[str]``: each prompt is encoded with `encode_one`.
        * ``list[int]``: already token ids (e.g. `apply_chat_template`
          output), returned unchanged.
        * ``str``: a single prompt, encoded with `encode_one`.

        ``text_pair`` and ``add_special_tokens`` are accepted for interface
        compatibility but are not used.
        """
        input_ids: Union[list[int], list[list[int]]]
        if is_list_of(text, str):
            input_ids = [
                self.encode_one(p, truncation, max_length) for p in text
            ]
        elif is_list_of(text, int):
            input_ids = text
        else:
            input_ids = self.encode_one(text, truncation, max_length)
        return BatchEncoding({"input_ids": input_ids})

    def get_vocab(self) -> dict[str, int]:
        # NB: the dictionary form of the vocabulary collapses token ids that map
        # to the same string but have different bytes
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> list[int]:
        """Encode one prompt using `encode`'s defaults (BOS prepended, no EOS).

        When ``truncation`` is truthy, only the first ``max_length`` ids are
        kept.
        """
        input_ids = self.encode(text)

        if truncation:
            input_ids = input_ids[:max_length]
        return input_ids

    def encode(self,
               text: str,
               truncation: Optional[bool] = None,
               max_length: Optional[int] = None,
               add_special_tokens: Optional[bool] = None) -> list[int]:
        """Encode ``text`` for prompt completion.

        `encode` should only be used for prompt completion. It should never
        be used for chat completion; use `apply_chat_template` for that.

        Args:
            text: Text to encode.
            truncation: When truthy, keep only the first ``max_length`` ids.
                All ids are kept when ``max_length`` is ``None``. This is
                the same rule as `encode_one`.
            max_length: Truncation length, used only with ``truncation``.
            add_special_tokens: When given, controls both BOS and EOS. By
                default only BOS is added.
        """
        if add_special_tokens is not None:
            input_ids = self.tokenizer.encode(text,
                                              bos=add_special_tokens,
                                              eos=add_special_tokens)
        else:
            input_ids = self.tokenizer.encode(text, bos=True, eos=False)

        if truncation:
            input_ids = input_ids[:max_length]
        return input_ids

    def apply_chat_template(self,
                            messages: list["ChatCompletionMessageParam"],
                            tools: Optional[list[dict[str, Any]]] = None,
                            **kwargs) -> list[int]:
        """Encode a chat conversation with mistral-common's request encoder.

        Extra ``kwargs`` are accepted for interface compatibility and ignored.
        """
        request = make_mistral_chat_completion_request(messages, tools)
        encoded = self.mistral.encode_chat_completion(request)

        # Only the token ids are returned.
        return encoded.tokens

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join token pieces back into text.

        Tekken: special tokens other than ``[TOOL_CALLS]`` are dropped. If
        any piece is raw ``bytes`` (see `convert_ids_to_tokens`), all pieces
        are mapped back to ids and decoded together. This reassembles UTF-8
        characters that are split across tokens.

        SentencePiece: ``[TOOL_CALLS]`` pieces are emitted verbatim. The runs
        of regular pieces between them are decoded by the backend.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokens
        if self.is_tekken:
            tokens = [
                t for t in tokens
                # `is` never matches a plain string piece. `==` also matches
                # the piece's string value when SpecialTokens is a
                # str-valued enum (as in mistral-common; confirm on upgrade).
                if (t == SpecialTokens.tool_calls
                    or t not in self.tokenizer._all_special_tokens)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                shift = self.tokenizer.num_special_tokens

                def _token_to_id(t: str):
                    t_bytes = (t if isinstance(t, bytes) else
                               t.encode("utf-8"))
                    try:
                        return (shift + self.tokenizer.
                                _tekken_token2id_nospecial[t_bytes])
                    except KeyError:
                        logger.warning(
                            "Failed to convert token %s to id,"
                            " replacing with <unk>", t_bytes)
                        return self.tokenizer.unk_id

                ids = [_token_to_id(t) for t in tokens]
                decoded = self.tokenizer.decode(ids)
            else:
                decoded = "".join(tokens)
        else:
            # make sure certain special tokens like Tool calls are
            # not decoded
            special_tokens = {SpecialTokens.tool_calls}
            regular_tokens: list[str] = []
            decoded_list = []

            for token in tokens:
                if token in special_tokens:
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(regular_tokens))
                        regular_tokens = []
                    decoded_list.append(token)
                else:
                    regular_tokens.append(token)

            if regular_tokens:
                decoded_list.append(
                    self.tokenizer.decode(regular_tokens))  # type: ignore

            decoded = ''.join(decoded_list)

        return decoded

    # WARN: Outlines logits processors can overwrite this method.
    # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
    # for more.
    def decode(self,
               ids: Union[list[int], int],
               skip_special_tokens: bool = True) -> str:
        """Decode one id or a list of ids to text.

        Only ``skip_special_tokens=True`` is supported.
        """
        assert (
            skip_special_tokens
        ), "skip_special_tokens=False is not supported for Mistral tokenizers."

        if isinstance(ids, int):
            ids = [ids]
        return self.tokenizer.decode(ids)

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = True,
    ) -> list[str]:
        """Map ids to token pieces.

        For Tekken, special tokens are dropped except ``[TOOL_CALLS]`` and,
        on V13 instruct tokenizers, the think delimiters. If any piece
        contains "�", byte pieces are returned instead.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokens
        from mistral_common.tokens.tokenizers.instruct import (
            InstructTokenizerV13)

        # TODO(Patrick) - potentially allow special tokens to not be skipped
        assert (
            skip_special_tokens
        ), "skip_special_tokens=False is not supported for Mistral tokenizers."

        assert self.is_tekken or self.is_spm, type(self.tokenizer)

        if self.is_tekken:
            # skip special tokens except tool call and think tokens
            non_skip_special_tokens = {
                self.tokenizer.get_control_token(SpecialTokens.tool_calls)
            }
            if isinstance(self.instruct, InstructTokenizerV13):
                if self.instruct.BEGIN_THINK:
                    non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
                if self.instruct.END_THINK:
                    non_skip_special_tokens.add(self.instruct.END_THINK)
            # Special tokens are the ids below `num_special_tokens`. The
            # first regular token is id `num_special_tokens` itself (see the
            # `shift + rank` mapping in `convert_tokens_to_string`), so `>=`
            # is required to keep it.
            ids = [
                i for i in ids if i >= self.tokenizer.num_special_tokens
                or i in non_skip_special_tokens
            ]

        tokens = [self.tokenizer.id_to_piece(id) for id in ids]

        if any("�" in t for t in tokens) and self.is_tekken:
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # if underlying tokenizer is sentencepiece, we just add "�"
            tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]

        return tokens