mistral.py 18.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import os
from dataclasses import dataclass
from pathlib import Path
7
from typing import TYPE_CHECKING, Any, Optional, Union, cast
8

9
import huggingface_hub
10
import regex as re
11
12
from huggingface_hub import HfApi, hf_hub_download

13
from vllm.logger import init_logger
14
from vllm.transformers_utils.tokenizer_base import TokenizerBase
15
from vllm.utils import is_list_of
16

17
if TYPE_CHECKING:
18
19
20
21
22
23
24
    # make sure `mistral_common` is lazy imported,
    # so that users who only use non-mistral models
    # will not be bothered by the dependency.
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import (
        MistralTokenizer as PublicMistralTokenizer)

25
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
26

27
28
logger = init_logger(__name__)

29
30
31

@dataclass
class Encoding:
    """Return value of ``MistralTokenizer.__call__``.

    Carries only ``input_ids``: a flat list of token ids for a single
    prompt, or a list of such lists when a batch of prompts was encoded.
    """

    # list[int] for one prompt, list[list[int]] for a batch of prompts.
    input_ids: Union[list[int], list[list[int]]]
33
34


35
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
    """Materialize ``tool_calls`` on every assistant message of *request*.

    After this call, every message in ``request.messages`` whose role is
    ``"assistant"`` has a ``tool_calls`` key holding a plain ``list``. The
    list is empty when the key was missing or explicitly ``None``. Messages
    with other roles are left untouched.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
    # declared as iterables are replaced in the instances by
    # pydantic-core ValidatorIterator instance. In particular, this
    # affects tool_calls defined in ChatCompletionAssistantMessageParam
    # model:
    # see:
    #   - https://github.com/pydantic/pydantic/issues/9467
    # As a result, tool_calls from assistant messages are never
    # deserialized in the request object if the tool_calls iterator is
    # not consumed. This affects messages passed to the MistralTokenizer
    # since no chat template is applied and therefore the tool_calls
    # iterator is not directly consumed.
    # Issue is tracked on Pydantic side, with resolution planned for
    # v2.11 release. In the meantime, the official workaround is to
    # consume the iterator so the tool_calls are correctly deserialized
    # in the OpenAI ChatCompletionAssistantMessageParam object
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # Official Pydantic Issues:
    #   - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        # `list()` drains pydantic's ValidatorIterator (the workaround
        # described above). It also accepts tool calls that are already a
        # list, which `next()` would reject with a TypeError. An explicit
        # `None` is treated like a missing key.
        request.messages[i]["tool_calls"] = ([] if tool_calls is None else
                                             list(tool_calls))


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
def truncate_tool_call_ids(request: "ChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements.

    Any string ID longer than 9 characters is replaced by its last 9
    characters, and a warning is logged. Both sides of a tool call are
    rewritten the same way so they keep matching:

    * the ``id`` of each entry in an assistant message's ``tool_calls``;
    * the ``tool_call_id`` of ``"tool"`` / ``"tool_results"`` messages.

    Every assistant message ends up with a ``tool_calls`` list, which is
    empty when the key was missing or ``None``.
    """
    max_id_len = 9

    def _truncate(tool_id, log_format: str):
        # Leave non-string IDs (e.g. None) alone instead of crashing on
        # len(); only over-long strings are shortened.
        if isinstance(tool_id, str) and len(tool_id) > max_id_len:
            logger.warning(log_format, tool_id, tool_id[-max_id_len:])
            return tool_id[-max_id_len:]
        return tool_id

    for i, message in enumerate(request.messages):
        role = message.get("role")
        if role == "assistant":
            # Materialize into a list first: a raw iterator would otherwise
            # be exhausted by the loop below and written back empty.
            tool_calls = list(message.get("tool_calls") or [])
            for tool_call in tool_calls:
                if "id" in tool_call:
                    tool_call["id"] = _truncate(
                        tool_call["id"], "Truncating tool call ID: %s to %s")

            request.messages[i]["tool_calls"] = tool_calls

        elif role in {"tool_results", "tool"} and "tool_call_id" in message:
            request.messages[i]["tool_call_id"] = _truncate(
                message["tool_call_id"], "Truncating tool_call_id: %s to %s")


102
103
104
105
106
107
108
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject request options that Mistral tokenizers cannot honour.

    ``skip_special_tokens`` left unset (``None``) or truthy is accepted.

    Raises:
        ValueError: if ``skip_special_tokens`` was explicitly set to a
            falsy value such as ``False``.
    """
    skip = request.skip_special_tokens
    if skip is None or skip:
        return
    raise ValueError("skip_special_tokens=False is not supported "
                     "for Mistral tokenizers.")


109
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
    """List the files of a locally cached HF Hub snapshot.

    Only the local Hugging Face cache is inspected; no network access is
    made.

    Args:
        repo_id: Hub repo id, split on ``"/"`` to build the cache folder name.
        revision: A commit hash, or a ref name (branch/tag) recorded under
            the cache's ``refs/`` folder. ``None`` means ``"main"``.

    Returns:
        Names of the top-level entries in the snapshot directory, or ``[]``
        when the revision cannot be found in the local cache.
    """
    repo_cache = os.path.join(
        huggingface_hub.constants.HF_HUB_CACHE,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *repo_id.split("/")]))

    # Resolve a ref name (default "main") to the commit hash stored in
    # refs/<name>. A commit hash has no such file and is used as-is.
    ref_name = "main" if revision is None else revision
    revision_file = os.path.join(repo_cache, "refs", ref_name)
    if os.path.isfile(revision_file):
        with open(revision_file, encoding="utf-8") as file:
            # Strip surrounding whitespace: a trailing newline would
            # otherwise produce a non-existent snapshot path.
            revision = file.read().strip()

    if revision:
        revision_dir = os.path.join(repo_cache, "snapshots", revision)
        if os.path.isdir(revision_dir):
            return os.listdir(revision_dir)

    return []


129
def find_tokenizer_file(files: list[str]):
    """Return the single Mistral tokenizer file name found in *files*.

    A name qualifies when it matches ``tokenizer.model.v*``,
    ``tokenizer.mm.model.v*`` or exactly ``tekken.json``.

    Raises:
        OSError: if no name, or more than one name, qualifies.
    """
    pattern = re.compile(
        r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
    candidates = list(filter(pattern.match, files))

    if not candidates:
        raise OSError(
            f"Found {len(candidates)} files matching the "
            f"pattern: `{pattern.pattern}`. Make sure that a Mistral "
            f"tokenizer is present in {files}.")
    if len(candidates) > 1:
        raise OSError(
            f"Found {len(candidates)} files matching the "
            f"pattern: `{pattern.pattern}`. Make sure only one Mistral "
            f"tokenizer is present in {files}.")

    (only_match, ) = candidates
    return only_match


148
def make_mistral_chat_completion_request(
    messages: list["ChatCompletionMessageParam"],
    tools: Optional[list[dict[str, Any]]] = None,
) -> "ChatCompletionRequest":
    """Build a ``mistral_common`` ``ChatCompletionRequest``.

    *messages* and *tools* are normalised IN PLACE before being handed over:

    * if the last message is an assistant message, it gets ``prefix=True``;
    * ``reasoning_content`` is removed from every message;
    * list-valued ``content`` of assistant/tool messages is flattened into a
      single string by joining the chunks' ``text`` fields with newlines;
    * every ``"function"`` tool whose ``parameters`` is missing or ``None``
      gets an empty dict.

    ``mistral_common`` is imported lazily so it is only required when a
    Mistral tokenizer is actually used.
    """
    final_message = cast(dict[str, Any], messages[-1])
    if final_message["role"] == "assistant":
        final_message["prefix"] = True

    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for msg in messages:
        # Remove reasoning_content as unsupported by Mistral
        msg.pop("reasoning_content", None)  # type: ignore

        if msg.get("role") not in ("assistant", "tool"):
            continue
        # Convert list text content to string
        content = msg.get("content")
        if isinstance(content, list):
            msg["content"] = "\n".join(part.get("text") for part in content)

    # The Mistral client, in comparison to the OpenAI client, requires the
    # "parameters" dict to be present, even if it's empty.
    if tools:
        # Collect every function spec up front, before mutating any of them.
        function_specs = [
            tool["function"] for tool in tools if tool["type"] == "function"
        ]
        for spec in function_specs:
            if spec.get("parameters") is None:
                spec["parameters"] = {}

    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    return ChatCompletionRequest(messages=messages,
                                 tools=tools)  # type: ignore[type-var]


185
class MistralTokenizer(TokenizerBase):
    """Adapter exposing a ``mistral_common`` tokenizer as a vLLM tokenizer.

    Only Tekken (``Tekkenizer``) and SentencePiece backends are supported;
    ``__init__`` raises ``TypeError`` for anything else. Decoding never
    emits special tokens (``skip_special_tokens=False`` is rejected). The
    one exception is the tool-calls control token
    (``SpecialTokens.tool_calls``), which is deliberately kept.
    """

    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
        self.mistral = tokenizer
        self.instruct = tokenizer.instruct_tokenizer

        # Keep the numeric part of the version string (the text after the
        # last "v"; presumably of the form "v<N>").
        _mistral_version_str = self.instruct.tokenizer.version.value
        self.version: int = int(_mistral_version_str.split("v")[-1])

        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
        from mistral_common.tokens.tokenizers.tekken import (
            SpecialTokenPolicy, Tekkenizer)
        self.is_tekken = isinstance(tokenizer_, Tekkenizer)
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer)
        self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
        if self.is_tekken:
            # Make sure special tokens will not raise
            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
        elif self.is_spm:
            pass
        else:
            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

        self._vocab = tokenizer_.vocab()
        # Convert to a dict[str, int] to match protocol, but this is a lossy
        # conversion. There may be multiple token ids that decode to the same
        # string due to partial UTF-8 byte sequences being converted to �
        self._vocab_dict = {
            token: idx
            for idx, token in enumerate(self._vocab)
        }
        self.tokenizer = tokenizer_
        self._max_token_id = self.vocab_size - 1

    @classmethod
    def from_pretrained(cls,
                        path_or_repo_id: str,
                        *,
                        revision: Optional[str] = None) -> "MistralTokenizer":
        """Load a tokenizer from a file, a directory, or an HF Hub repo.

        * Non-existent path: treated as an ``"<org>/<name>"`` Hub repo id,
          and the tokenizer file is downloaded at *revision*.
        * Directory: the single tokenizer file inside it is used.
        * File: used directly.
        """
        if not Path(path_or_repo_id).exists():
            assert len(path_or_repo_id.split("/")) == 2, (
                "You have either provided a non-existent path: "
                f"{path_or_repo_id} or an invalid HF Hub repo id.")
            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
                path_or_repo_id, revision)
        elif Path(path_or_repo_id).is_dir():
            tokenizer_file_name = find_tokenizer_file(
                os.listdir(path_or_repo_id))
            tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
        else:
            assert Path(
                path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
            tokenizer_file = str(Path(path_or_repo_id))

        from mistral_common.tokens.tokenizers.mistral import (
            MistralTokenizer as PublicMistralTokenizer)
        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
        return cls(mistral_tokenizer)

    @staticmethod
    def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                            revision: Optional[str]) -> str:
        """Download the repo's tokenizer file and return its local path.

        If listing the repo's files raises ``ConnectionError``, the listing
        falls back to the local HF cache. The original error is re-raised
        when the cache has nothing either.
        """
        try:
            hf_api = HfApi()
            files = hf_api.list_repo_files(repo_id=tokenizer_name,
                                           revision=revision)
        # NOTE(review): this is the builtin ConnectionError; confirm it
        # covers the network errors huggingface_hub actually raises.
        except ConnectionError as exc:
            files = list_local_repo_files(repo_id=tokenizer_name,
                                          revision=revision)

            if len(files) == 0:
                raise exc

        filename = find_tokenizer_file(files)

        tokenizer_file = hf_hub_download(tokenizer_name,
                                         filename=filename,
                                         revision=revision)
        return tokenizer_file

    # the following attributes are set to fit vLLM's design and are used
    # by the guided structured output backends.
    @property
    def all_special_tokens_extended(self) -> list[str]:
        """Special tokens as plain strings (enum members unwrapped)."""
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        # tekken defines its own extended special tokens list
        if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
            special_tokens = self.tokenizer.SPECIAL_TOKENS
        else:
            special_tokens = list(SpecialTokens)
        return [
            s.value if isinstance(s, SpecialTokens) else s
            for s in special_tokens
        ]

    @property
    def all_special_tokens(self) -> list[str]:
        return self.all_special_tokens_extended

    @property
    def all_special_ids(self) -> list[int]:
        # Each token's position in the special-token list is used as its id.
        # NOTE(review): this assumes special tokens occupy the first ids in
        # list order; confirm for the SentencePiece backend.
        # Build the list once: re-reading the property inside the
        # comprehension rebuilt it for every element.
        special_tokens = self.all_special_tokens
        return [special_tokens.index(t) for t in special_tokens]

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def sep_token(self) -> str:
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        raise NotImplementedError()

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    @property
    def max_token_id(self) -> int:
        return self._max_token_id

    def __len__(self) -> int:
        return self.vocab_size

    def __call__(
        self,
        text: Union[str, list[str], list[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Tokenize *text* HF-style and wrap the ids in an ``Encoding``.

        * ``list[str]``: each prompt is encoded separately (batch result).
        * ``list[int]``: already tokens (output of ``apply_chat_template``)
          and passed through unchanged.
        * ``str``: a single prompt.

        ``text_pair`` and ``add_special_tokens`` are accepted for interface
        compatibility and ignored.
        """
        input_ids: Union[list[int], list[list[int]]]
        # For list[str], original prompt text
        if is_list_of(text, str):
            input_ids_: list[list[int]] = []
            for p in text:
                each_input_ids = self.encode_one(p, truncation, max_length)
                input_ids_.append(each_input_ids)
            input_ids = input_ids_
        # For list[int], apply chat template output, already tokens.
        elif is_list_of(text, int):
            input_ids = text
        # For str, single prompt text
        else:
            input_ids = self.encode_one(text, truncation, max_length)
        return Encoding(input_ids=input_ids)

    def get_vocab(self) -> dict[str, int]:
        # NB: the dictionary form of the vocabulary collapses token ids that map
        # to the same string but have different bytes
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> list[int]:
        """Encode one prompt, keeping at most *max_length* ids if *truncation*."""
        # Mistral Tokenizers should not add special tokens
        input_ids = self.encode(text)

        if truncation:
            input_ids = input_ids[:max_length]
        return input_ids

    def encode(self,
               text: str,
               truncation: Optional[bool] = None,
               max_length: Optional[int] = None,
               add_special_tokens: Optional[bool] = None) -> list[int]:
        """Encode *text*; ``truncation`` and ``max_length`` are ignored.

        When ``add_special_tokens`` is given, it controls both BOS and EOS.
        Otherwise BOS is added and EOS is not.
        """
        # `encode` should only be used for prompt completion
        # it should never be used for chat_completion.
        # For chat completion use `apply_chat_template`
        if add_special_tokens is not None:
            return self.tokenizer.encode(text,
                                         bos=add_special_tokens,
                                         eos=add_special_tokens)
        else:
            return self.tokenizer.encode(text, bos=True, eos=False)

    def apply_chat_template(self,
                            messages: list["ChatCompletionMessageParam"],
                            tools: Optional[list[dict[str, Any]]] = None,
                            **kwargs) -> list[int]:
        """Tokenize a chat conversation with mistral_common's own template.

        Extra keyword arguments are accepted and ignored.
        """
        request = make_mistral_chat_completion_request(messages, tools)
        encoded = self.mistral.encode_chat_completion(request)

        # Only the token ids are returned.
        return encoded.tokens

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join *tokens* into text, dropping special tokens except tool calls."""
        from mistral_common.tokens.tokenizers.base import SpecialTokens
        if self.is_tekken:
            all_special = self.tokenizer._all_special_tokens
            # `==`, not `is`: tokens are plain str (or bytes), so an
            # identity test against the enum member never matched and the
            # tool-calls exemption was dead code.
            tokens = [
                t for t in tokens
                if (t == SpecialTokens.tool_calls or t not in all_special)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                shift = self.tokenizer.num_special_tokens

                def _token_to_id(t: str):
                    t_bytes = t.encode("utf-8") \
                        if not isinstance(t, bytes) else t
                    try:
                        return shift + \
                            self.tokenizer._tekken_token2id_nospecial[t_bytes]
                    except KeyError:
                        logger.warning(
                            "Failed to convert token %s to id,"
                            " replacing with <unk>", t_bytes)
                        return self.tokenizer.unk_id

                ids = [_token_to_id(t) for t in tokens]
                decoded = self.tokenizer.decode(ids)
            else:
                decoded = "".join(tokens)
        else:
            # make sure certain special tokens like Tool calls are
            # not decoded
            special_tokens = {SpecialTokens.tool_calls}
            regular_tokens: list[str] = []
            decoded_list = []

            for token in tokens:
                if token in special_tokens:
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(regular_tokens))
                        regular_tokens = []
                    decoded_list.append(token)
                else:
                    regular_tokens.append(token)

            if regular_tokens:
                decoded_list.append(
                    self.tokenizer.decode(regular_tokens))  # type: ignore

            decoded = ''.join(decoded_list)

        return decoded

    # WARN: Outlines logits processors can overwrite this method.
    # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
    # for more.
    def decode(self,
               ids: Union[list[int], int],
               skip_special_tokens: bool = True) -> str:
        """Decode one id or a list of ids to text."""
        assert (
            skip_special_tokens
        ), "skip_special_tokens=False is not supported for Mistral tokenizers."

        if isinstance(ids, int):
            ids = [ids]
        return self.tokenizer.decode(ids)

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = True,
    ) -> list[str]:
        """Map ids to token pieces.

        For Tekken, special ids are dropped except tool calls, and byte
        pieces are returned when any piece contains an incomplete UTF-8
        character.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokens

        # TODO(Patrick) - potentially allow special tokens to not be skipped
        assert (
            skip_special_tokens
        ), "skip_special_tokens=False is not supported for Mistral tokenizers."

        assert self.is_tekken or self.is_spm, type(self.tokenizer)

        if self.is_tekken:
            # skip special tokens except tool call.
            # Regular ids start AT num_special_tokens (this is the `shift`
            # used in convert_tokens_to_string), so `>=` is required. `>`
            # wrongly dropped the first regular token.
            ids = [
                i for i in ids if i >= self.tokenizer.num_special_tokens or i
                == self.tokenizer.get_control_token(SpecialTokens.tool_calls)
            ]

        tokens = [self.tokenizer.id_to_piece(id) for id in ids]

        if any("�" in t for t in tokens) and self.is_tekken:
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # if underlying tokenizer is sentencepiece, we just add "�"
            tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]

        return tokens