mistral.py 19.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import os
from pathlib import Path
6
from typing import TYPE_CHECKING, Any, Optional, Union, cast
7

8
import huggingface_hub
9
import regex as re
10
from huggingface_hub import HfApi, hf_hub_download
11
from transformers.tokenization_utils_base import BatchEncoding
12

13
from vllm.logger import init_logger
14
from vllm.transformers_utils.tokenizer_base import TokenizerBase
15
from vllm.utils import is_list_of
16

17
if TYPE_CHECKING:
18
19
20
21
22
    # make sure `mistral_common` is lazy imported,
    # so that users who only use non-mistral models
    # will not be bothered by the dependency.
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import (
23
24
        MistralTokenizer as PublicMistralTokenizer,
    )
25

26
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
27

28
29
# Module-level logger for the warnings emitted by this file.
logger = init_logger(__name__)

30

31
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
    """Materialize ``tool_calls`` on assistant messages as plain lists.

    Every assistant message in ``request.messages`` gets its ``tool_calls``
    entry replaced, in place, by a list of its elements. A missing or
    ``None`` entry becomes an empty list. Non-assistant messages are left
    untouched.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
    # declared as iterables are replaced in the instances by a
    # pydantic-core ValidatorIterator instance. In particular, this
    # affects tool_calls defined in ChatCompletionAssistantMessageParam
    # model:
    # see:
    #   - https://github.com/pydantic/pydantic/issues/9467
    # As a result, tool_calls from assistant messages are never
    # deserialized in the request object if the tool_calls iterator is
    # not consumed. This affects messages passed to the MistralTokenizer
    # since no chat template is applied and therefore the tool_calls
    # iterator is not directly consumed.
    # Issue is tracked on Pydantic side, with resolution planned for
    # v2.11 release. In the meantime, the official workaround is to
    # consume the iterator so the tool_calls are correctly deserialized
    # in the OpenAI ChatCompletionAssistantMessageParam object
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # Official Pydantic Issues:
    #   - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue
        tool_calls = message.get("tool_calls")
        # ``list()`` consumes an iterator (such as pydantic's
        # ValidatorIterator) and also accepts an already-materialized list or
        # tuple, so calling this function twice is safe. Calling ``next()``
        # directly would fail on a list or on ``None``.
        request.messages[i]["tool_calls"] = (
            [] if tool_calls is None else list(tool_calls)
        )


68
69
70
def _truncate_tool_call_id(
    tool_call_id: str, log_msg: str, max_len: int = 9
) -> str:
    """Return the last ``max_len`` characters of ``tool_call_id``.

    IDs already within ``max_len`` characters are returned unchanged.
    Otherwise a warning is logged with ``log_msg``, formatted with the old
    and the new ID. ``max_len`` must be positive.
    """
    if len(tool_call_id) <= max_len:
        return tool_call_id
    truncated = tool_call_id[-max_len:]
    logger.warning(log_msg, tool_call_id, truncated)
    return truncated


def truncate_tool_call_ids(request: "ChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements.

    Mutates ``request.messages`` in place, shortening IDs to their last
    9 characters:

    * assistant messages: the ``id`` of every entry in ``tool_calls``.
      A missing ``tool_calls`` key is set to ``[]``; an explicit ``None``
      is kept as-is.
    * ``tool`` / ``tool_results`` messages: their ``tool_call_id``, if present.
    """
    for i, message in enumerate(request.messages):
        role = message.get("role")
        if role == "assistant":
            tool_calls = message.get("tool_calls", [])
            # An explicit ``None`` means "no tool calls"; iterating it would
            # raise TypeError.
            for tool_call in tool_calls or ():
                tool_call["id"] = _truncate_tool_call_id(
                    tool_call["id"], "Truncating tool call ID: %s to %s"
                )

            request.messages[i]["tool_calls"] = tool_calls

        elif role in {"tool_results", "tool"} and "tool_call_id" in message:
            request.messages[i]["tool_call_id"] = _truncate_tool_call_id(
                message["tool_call_id"], "Truncating tool_call_id: %s to %s"
            )


98
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject request options that Mistral tokenizers cannot honour.

    Raises:
        ValueError: if ``skip_special_tokens`` is set (not ``None``) to a
            falsy value.
    """
    skip_special = request.skip_special_tokens
    # Unset (None) or truthy values are acceptable; only an explicit
    # "don't skip" is refused.
    if skip_special is None or skip_special:
        return
    raise ValueError(
        "skip_special_tokens=False is not supported for Mistral tokenizers."
    )
103
104


105
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
    """List the entries of a locally cached Hugging Face Hub model snapshot.

    Looks in ``<HF_HUB_CACHE>/models<sep><org><sep><name>`` (``<sep>`` being
    ``huggingface_hub.constants.REPO_ID_SEPARATOR``).

    ``revision`` may be a commit hash, or a ref name (branch or tag) that has
    a file under ``refs/``; ``None`` means ``"main"``. A ref name is resolved
    to the commit hash stored in that ref file.

    Returns:
        The names returned by ``os.listdir`` for ``snapshots/<commit>``, or
        an empty list if the revision cannot be resolved or is not cached.
    """
    repo_cache = os.path.join(
        huggingface_hub.constants.HF_HUB_CACHE,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *repo_id.split("/")]
        ),
    )

    # Snapshots are keyed by commit hash, so resolve ref names (including the
    # implicit "main") via refs/<name>. A value that is already a commit hash
    # has no ref file and is used as-is.
    ref_name = "main" if revision is None else revision
    ref_file = os.path.join(repo_cache, "refs", ref_name)
    if ref_name and os.path.isfile(ref_file):
        with open(ref_file, encoding="utf-8") as file:
            # Strip surrounding whitespace (e.g. a trailing newline) so the
            # hash can be used as a directory name.
            revision = file.read().strip()

    if revision:
        revision_dir = os.path.join(repo_cache, "snapshots", revision)
        if os.path.isdir(revision_dir):
            return os.listdir(revision_dir)

    return []


127
def find_tokenizer_file(files: list[str]):
    """Pick the Mistral tokenizer file out of a list of file names.

    A name is accepted only if it is, in its entirety, one of:
      * ``tokenizer.model`` or a versioned ``tokenizer.model.v*``
      * ``tekken.json``
      * ``tokenizer.mm.model`` or a versioned ``tokenizer.mm.model.v*``

    Returns:
        The first accepted name, in the order given. A warning is logged if
        more than one name is accepted.

    Raises:
        OSError: if no name is accepted.
    """
    # The alternatives are grouped so that the anchors apply to all of them.
    # Written as "^a|b|c$", "^" binds only "a" and "$" only "c", which let
    # names like "tekken.json.bak" or "tokenizer.model.bak" through.
    file_pattern = re.compile(
        r"^(?:"
        r"tokenizer\.model(?:\.v.*)?"
        r"|tekken\.json"
        r"|tokenizer\.mm\.model(?:\.v.*)?"
        r")$"
    )

    # fullmatch also rejects a trailing newline, which "$" alone would allow.
    matched_files = [file for file in files if file_pattern.fullmatch(file)]
    if not matched_files:
        raise OSError(
            f"Found {len(matched_files)} files matching the "
            f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral "
            f"tokenizer is present in {files}."
        )
    if len(matched_files) > 1:
        logger.warning(
            "Multiple files matched pattern `%s`: %s. Using %s.",
            file_pattern.pattern,
            matched_files,
            matched_files[0],
        )

    return matched_files[0]


Julien Denize's avatar
Julien Denize committed
153
154
155
def _aggregate_content(content: list) -> Union[str, list[dict[str, Any]]]:
    """Merge runs of consecutive text chunks in a message content list.

    Adjacent chunks with ``"type" == "text"`` are joined into one text chunk,
    their texts separated by a blank line (two newline characters). Other
    chunks keep their position.

    Returns:
        The merged text as a plain string if everything merged into a single
        text chunk; otherwise the merged list of chunks.

    The input list and its chunk dicts are not modified.
    """
    aggregated_content: list[dict[str, Any]] = []
    for chunk in content:
        if (
            chunk.get("type") == "text"
            and aggregated_content
            and aggregated_content[-1].get("type") == "text"
        ):
            previous = aggregated_content[-1]
            # Build a new dict instead of appending to ``previous["text"]``:
            # ``previous`` may be one of the caller's own chunk dicts.
            aggregated_content[-1] = {
                **previous,
                "text": previous["text"] + "\n\n" + chunk.get("text"),
            }
        else:
            aggregated_content.append(chunk)
    if len(aggregated_content) == 1 and aggregated_content[0].get("type") == "text":
        return aggregated_content[0]["text"]
    # Return the merged list; returning ``content`` would hand back the
    # unmerged input.
    return aggregated_content


169
def make_mistral_chat_completion_request(
    messages: list["ChatCompletionMessageParam"],
    tools: Optional[list[dict[str, Any]]] = None,
) -> "ChatCompletionRequest":
    """Build a mistral-common ``ChatCompletionRequest`` from OpenAI-style input.

    ``messages`` and ``tools`` are normalized IN PLACE before the request is
    built:

    * a trailing assistant message is flagged with ``prefix=True``;
    * ``reasoning_content`` is removed from every message;
    * list content of assistant/tool messages is merged via
      ``_aggregate_content`` (collapsed to a string when possible);
    * each function tool gets ``parameters={}`` / ``description=""`` when
      those are missing or ``None``.

    Raises:
        IndexError: if ``messages`` is empty.
    """
    # A final assistant turn is flagged as a prefix (presumably so that the
    # model continues that message instead of starting a new one).
    final_turn = cast(dict[str, Any], messages[-1])
    if final_turn["role"] == "assistant":
        final_turn["prefix"] = True

    for msg in messages:
        # mistral-common has no field for reasoning content; drop it.
        msg.pop("reasoning_content", None)  # type: ignore

        if msg.get("role") not in ("assistant", "tool"):
            continue
        # mistral-common requires AssistantMessage content to be string [1],
        # so chunked text content is merged first.
        #
        # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
        body: Any = msg.get("content")
        msg["content"] = _aggregate_content(body) if isinstance(body, list) else body

    # Unlike the OpenAI client, the Mistral client requires the "parameters"
    # dict and the "description" string to be present even if they are empty.
    if tools:
        function_specs = [
            spec["function"] for spec in tools if spec["type"] == "function"
        ]
        for spec in function_specs:
            if spec.get("parameters") is None:
                spec["parameters"] = {}
            if spec.get("description") is None:
                spec["description"] = ""

    from mistral_common.protocol.instruct.request import ChatCompletionRequest

    return ChatCompletionRequest(messages=messages, tools=tools)  # type: ignore[type-var]
206

207

208
class MistralTokenizer(TokenizerBase):
    """Adapter exposing a mistral-common tokenizer through vLLM's TokenizerBase.

    Holds three objects from mistral-common:

    * the public ``MistralTokenizer``, as ``self.mistral``;
    * its instruct tokenizer, as ``self.instruct``;
    * its raw tokenizer, as ``self.tokenizer``.

    Only Tekken (``Tekkenizer``) and SentencePiece backends are accepted;
    anything else raises ``TypeError``. ``decode`` and
    ``convert_ids_to_tokens`` only support ``skip_special_tokens=True``.
    """

    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
        """Wrap an already-loaded mistral-common tokenizer.

        Raises:
            TypeError: if the underlying tokenizer is neither a ``Tekkenizer``
                nor a ``SentencePieceTokenizer``.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        self.mistral = tokenizer
        self.instruct = tokenizer.instruct_tokenizer

        # The version value is parsed as the integer after its last "v"
        # (presumably values such as "v3" or "v7").
        _mistral_version_str = self.instruct.tokenizer.version.value
        self.version: int = int(_mistral_version_str.split("v")[-1])

        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
        self.is_tekken = isinstance(tokenizer_, Tekkenizer)
        self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
        # Tekken decoding is told to IGNORE special tokens; SentencePiece is
        # given no explicit policy.
        self._special_token_policy = (
            SpecialTokenPolicy.IGNORE if self.is_tekken else None
        )
        if not (self.is_tekken or self.is_spm):
            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

        self._vocab = tokenizer_.vocab()
        # Convert to a dict[str, int] to match protocol, but this is a lossy
        # conversion. There may be multiple token ids that decode to the same
        # string due to partial UTF-8 byte sequences being converted to �
        # (the highest such id wins).
        self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)}
        self.tokenizer = tokenizer_
        self._max_token_id = self.vocab_size - 1

    @classmethod
    def from_pretrained(
        cls, path_or_repo_id: str, *, revision: Optional[str] = None
    ) -> "MistralTokenizer":
        """Load a tokenizer from a local file, a local directory, or an HF repo.

        * An existing directory is searched with ``find_tokenizer_file``.
        * An existing file is loaded directly.
        * A non-existent path must look like ``"org/name"``; the tokenizer
          file is then fetched from the Hugging Face Hub at ``revision``.

        ``revision`` is only used for the Hub case.
        """
        if not Path(path_or_repo_id).exists():
            assert len(path_or_repo_id.split("/")) == 2, (
                "You have either provided a non-existent path: "
                f"{path_or_repo_id} or an invalid HF Hub repo id."
            )
            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
                path_or_repo_id, revision
            )
        elif Path(path_or_repo_id).is_dir():
            tokenizer_file_name = find_tokenizer_file(os.listdir(path_or_repo_id))
            tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
        else:
            assert Path(path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
            tokenizer_file = str(Path(path_or_repo_id))

        from mistral_common.tokens.tokenizers.mistral import (
            MistralTokenizer as PublicMistralTokenizer,
        )

        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
        return cls(mistral_tokenizer)

    @staticmethod
    def _download_mistral_tokenizer_from_hf(
        tokenizer_name: str, revision: Optional[str]
    ) -> str:
        """Find and download the tokenizer file of a Hugging Face Hub repo.

        The repo's file list is requested from the Hub. On ``ConnectionError``
        the locally cached snapshot is listed instead (``list_local_repo_files``),
        and the error is re-raised if that finds nothing.

        Returns:
            The local path returned by ``hf_hub_download``.
        """
        try:
            hf_api = HfApi()
            files = hf_api.list_repo_files(repo_id=tokenizer_name, revision=revision)
        except ConnectionError as exc:
            files = list_local_repo_files(repo_id=tokenizer_name, revision=revision)

            if len(files) == 0:
                raise exc

        filename = find_tokenizer_file(files)

        tokenizer_file = hf_hub_download(
            tokenizer_name, filename=filename, revision=revision
        )
        return tokenizer_file

    # the following attributes are set to fit vLLM's design and are used
    # by the structured output backends.
    @property
    def all_special_tokens_extended(self) -> list[str]:
        """Special-token strings of the underlying tokenizer.

        Uses the tokenizer's own ``SPECIAL_TOKENS`` when it defines one
        (Tekken), otherwise every member of mistral-common's ``SpecialTokens``.
        Enum members are converted to their string values.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        # tekken defines its own extended special tokens list
        if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
            special_tokens = self.tokenizer.SPECIAL_TOKENS
        else:
            special_tokens = list(SpecialTokens)
        return [s.value if isinstance(s, SpecialTokens) else s for s in special_tokens]

    @property
    def all_special_tokens(self) -> list[str]:
        return self.all_special_tokens_extended

    @property
    def all_special_ids(self) -> list[int]:
        # NOTE(review): these are positions within ``all_special_tokens``
        # (first occurrence), not ids looked up from the tokenizer; they equal
        # the real token ids only if the special tokens are listed in id
        # order — confirm.
        # Build the list once; the property rebuilds it on every access.
        special_tokens = self.all_special_tokens
        return [special_tokens.index(t) for t in special_tokens]

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def sep_token(self) -> str:
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        raise NotImplementedError()

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    @property
    def max_token_id(self) -> int:
        return self._max_token_id

    @property
    def truncation_side(self) -> str:
        raise NotImplementedError()

    def __len__(self) -> int:
        return self.vocab_size

    def __call__(
        self,
        text: Union[str, list[str], list[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Tokenize like a Hugging Face tokenizer call.

        Returns a ``BatchEncoding`` with ``input_ids``:

        * ``list[str]``: one id list per prompt;
        * ``list[int]``: the ids unchanged (already tokenized);
        * ``str``: a single id list.

        ``text_pair`` and ``add_special_tokens`` are accepted for interface
        compatibility but not used.
        """
        input_ids: Union[list[int], list[list[int]]]
        # For list[str], original prompt text
        if is_list_of(text, str):
            input_ids_: list[list[int]] = []
            for p in text:
                each_input_ids = self.encode_one(p, truncation, max_length)
                input_ids_.append(each_input_ids)
            input_ids = input_ids_
        # For list[int], apply chat template output, already tokens.
        elif is_list_of(text, int):
            input_ids = text
        # For str, single prompt text
        else:
            input_ids = self.encode_one(text, truncation, max_length)
        return BatchEncoding({"input_ids": input_ids})

    def get_vocab(self) -> dict[str, int]:
        # NB: the dictionary form of the vocabulary collapses token ids that map
        # to the same string but have different bytes
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> list[int]:
        """Encode one prompt via ``encode`` (BOS added, no EOS).

        When ``truncation`` is set, keeps the first ``max_length`` ids
        (all of them if ``max_length`` is ``None``).
        """
        # Mistral Tokenizers should not add special tokens
        input_ids = self.encode(text)

        if truncation:
            input_ids = input_ids[:max_length]
        return input_ids

    def encode(
        self,
        text: str,
        truncation: Optional[bool] = None,
        max_length: Optional[int] = None,
        add_special_tokens: Optional[bool] = None,
    ) -> list[int]:
        """Encode raw text.

        ``add_special_tokens`` toggles both BOS and EOS together. When it is
        ``None``, BOS is added and EOS is not. ``truncation`` and
        ``max_length`` are accepted for interface compatibility but ignored.
        """
        # `encode` should only be used for prompt completion
        # it should never be used for chat_completion.
        # For chat completion use `apply_chat_template`
        if add_special_tokens is not None:
            return self.tokenizer.encode(
                text, bos=add_special_tokens, eos=add_special_tokens
            )
        else:
            return self.tokenizer.encode(text, bos=True, eos=False)

    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: Optional[list[dict[str, Any]]] = None,
        **kwargs,
    ) -> list[int]:
        """Tokenize a chat conversation with mistral-common's own encoding.

        ``messages``/``tools`` are normalized in place by
        ``make_mistral_chat_completion_request``. Extra ``kwargs`` are
        ignored.
        """
        request = make_mistral_chat_completion_request(messages, tools)
        encoded = self.mistral.encode_chat_completion(request)

        # Only the token ids are needed; other fields of the encoded result
        # are discarded.
        return encoded.tokens

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join tokens (as produced by ``convert_ids_to_tokens``) into text.

        Tekken:
            Special tokens other than the tool-call token are dropped. If any
            token is ``bytes`` (the byte pieces ``convert_ids_to_tokens``
            emits for partial UTF-8), all tokens are mapped back to ids and
            decoded together. Otherwise they are concatenated.

        SentencePiece:
            Runs of regular tokens are decoded by the tokenizer, and
            tool-call tokens are inserted verbatim between them.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        if self.is_tekken:
            # Equality, not identity: tokens are plain strings while
            # SpecialTokens.tool_calls is a str-valued enum member, so an
            # ``is`` check would never keep the tool-call token.
            tokens = [
                t
                for t in tokens
                if (
                    t == SpecialTokens.tool_calls
                    or t not in self.tokenizer._all_special_tokens
                )
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                # Regular Tekken ids start right after the special-token ids.
                shift = self.tokenizer.num_special_tokens

                def _token_to_id(t: str):
                    t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
                    try:
                        return (
                            shift + self.tokenizer._tekken_token2id_nospecial[t_bytes]
                        )
                    except KeyError:
                        logger.warning(
                            "Failed to convert token %s to id, replacing with <unk>",
                            t_bytes,
                        )
                        return self.tokenizer.unk_id

                ids = [_token_to_id(t) for t in tokens]
                decoded = self.tokenizer.decode(ids, self._special_token_policy)
            else:
                decoded = "".join(tokens)
        else:
            # make sure certain special tokens like Tool calls are
            # not decoded
            special_tokens = {SpecialTokens.tool_calls}
            regular_tokens: list[str] = []
            decoded_list = []

            for token in tokens:
                if token in special_tokens:
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(
                                regular_tokens, self._special_token_policy
                            )
                        )
                        regular_tokens = []
                    decoded_list.append(token)
                else:
                    regular_tokens.append(token)

            if regular_tokens:
                decoded_list.append(
                    self.tokenizer.decode(regular_tokens, self._special_token_policy)
                )

            decoded = "".join(decoded_list)

        return decoded

    def decode(
        self, ids: Union[list[int], int], skip_special_tokens: bool = True
    ) -> str:
        """Decode one id or a list of ids; only ``skip_special_tokens=True``."""
        assert skip_special_tokens, (
            "skip_special_tokens=False is not supported for Mistral tokenizers."
        )

        if isinstance(ids, int):
            ids = [ids]
        return self.tokenizer.decode(ids, self._special_token_policy)

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = True,
    ) -> list[str]:
        """Map ids to token pieces; only ``skip_special_tokens=True``.

        For Tekken, special-token ids are dropped except the tool-call token
        and, for ``InstructTokenizerV13``, the think begin/end tokens. If any
        Tekken piece contains U+FFFD, byte pieces are returned for all ids.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokens
        from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13

        # TODO(Patrick) - potentially allow special tokens to not be skipped
        assert skip_special_tokens, (
            "skip_special_tokens=False is not supported for Mistral tokenizers."
        )

        assert self.is_tekken or self.is_spm, type(self.tokenizer)

        if self.is_tekken:
            # skip special tokens except tool call and think tokens
            non_skip_special_tokens = {
                self.tokenizer.get_control_token(SpecialTokens.tool_calls)
            }
            if isinstance(self.instruct, InstructTokenizerV13):
                if self.instruct.BEGIN_THINK:
                    non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
                if self.instruct.END_THINK:
                    non_skip_special_tokens.add(self.instruct.END_THINK)
            # Special tokens occupy ids [0, num_special_tokens); the first
            # regular id is num_special_tokens itself (matching the
            # ``shift + rank`` mapping in convert_tokens_to_string), so the
            # comparison must be ``>=``.
            ids = [
                i
                for i in ids
                if i >= self.tokenizer.num_special_tokens
                or i in non_skip_special_tokens
            ]

        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        if any("�" in t for t in tokens) and self.is_tekken:
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # if underlying tokenizer is sentencepiece, we just add "�"
            tokens = [
                self.tokenizer.id_to_byte_piece(token_id, self._special_token_policy)
                for token_id in ids
            ]

        return tokens