tokenizer.py 9.88 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import copy
6
import os
7
import warnings
8
from functools import lru_cache
9
from pathlib import Path
10
from typing import TYPE_CHECKING, Any, TypeAlias
11

12
import huggingface_hub
13
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
14
from typing_extensions import assert_never
15

16
from vllm import envs
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.logger import init_logger
18
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
19
from vllm.transformers_utils.tokenizers import MistralTokenizer
20
from vllm.transformers_utils.utils import check_gguf_file
21

22
23
if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
    # Outside of static type checking these names are only used in
    # annotations, so bind them all to Any rather than importing the
    # real modules at runtime.
    ModelConfig = LoRARequest = TokenizerBase = Any
30

31
32
logger = init_logger(__name__)

# Every tokenizer flavour accepted by the helpers in this module: the two
# HuggingFace tokenizer classes plus anything implementing TokenizerBase.
AnyTokenizer: TypeAlias = (
    PreTrainedTokenizer | PreTrainedTokenizerFast | TokenizerBase
)
34

35

36
37
38
39
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
40
    skip_special_tokens: bool | None = None,
41
42
43
) -> str:
    """
    Backend-agnostic equivalent of HF's
44
    `tokenizer.decode(token_ids, ...)`.
45

46
    `skip_special_tokens=None` means to use the backend's default
47
    settings.
48
    """
49
    if skip_special_tokens is not None:
50
        return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
51

52
    return tokenizer.decode(token_ids)
53
54


55
56
57
58
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
59
60
61
    truncation: bool | None = None,
    max_length: int | None = None,
    add_special_tokens: bool | None = None,
62
63
64
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
65
    `tokenizer.encode(text, ...)`.
66

67
    `add_special_tokens=None` means to use the backend's default
68
    settings.
69
    """
70
71
72
73
74
75
76
77

    kw_args: dict[str, Any] = {}
    if max_length is not None:
        kw_args["max_length"] = max_length

    if truncation is not None:
        kw_args["truncation"] = truncation

78
    if add_special_tokens is not None:
79
        kw_args["add_special_tokens"] = add_special_tokens
80

81
    return tokenizer.encode(text, **kw_args)
82
83


84
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """
    Return a copy of `tokenizer` whose expensive properties are precomputed.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown.
    This proxy caches these properties for faster access.

    The result is a shallow copy of `tokenizer` whose class is swapped for a
    dynamically created subclass. That subclass serves `all_special_ids`,
    `all_special_tokens`, `all_special_tokens_extended`, `get_vocab()` and
    `len()` from values captured here. It also adds a `max_token_id`
    property.

    Args:
        tokenizer: The tokenizer to wrap. Its own class is left unchanged.

    Returns:
        The cached copy, named `Cached<OriginalClassName>`.
    """
    cached_tokenizer = copy.copy(tokenizer)

    tokenizer_all_special_ids = tokenizer.all_special_ids
    tokenizer_all_special_tokens = tokenizer.all_special_tokens
    tokenizer_all_special_tokens_extended = tokenizer.all_special_tokens_extended
    tokenizer_vocab = tokenizer.get_vocab()
    tokenizer_len = len(tokenizer)

    # Largest id present in the vocab dict (raises ValueError if it is empty).
    max_token_id = max(tokenizer_vocab.values())

    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    # `vocab_size` may be absent (AttributeError) or an unimplemented
    # property that raises NotImplementedError. `hasattr` only swallows
    # AttributeError, so the attribute read itself has to sit inside the
    # suppress block; otherwise NotImplementedError escapes.
    # NOTE(review): vocab_size is a count, so this can make max_token_id one
    # larger than the largest real id; kept as-is for backward compatibility.
    with contextlib.suppress(AttributeError, NotImplementedError):
        max_token_id = max(max_token_id, tokenizer.vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore
        @property
        def all_special_ids(self) -> list[int]:
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self) -> list[str]:
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self) -> list[str]:
            return tokenizer_all_special_tokens_extended

        @property
        def max_token_id(self) -> int:
            return max_token_id

        def get_vocab(self) -> dict[str, int]:
            return tokenizer_vocab

        def __len__(self) -> int:
            return tokenizer_len

        def __reduce__(self):
            # This class is created inside a function, so it cannot be found
            # by name. Pickle by re-wrapping the original tokenizer instead.
            return get_cached_tokenizer, (tokenizer,)

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    cached_tokenizer.__class__ = CachedTokenizer
    return cached_tokenizer
137
138


139
def get_tokenizer(
    tokenizer_name: str | Path,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: str | None = None,
    download_dir: str | None = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Hub repo id or local path. A path that is a GGUF file
            is split into its parent directory plus a `gguf_file` kwarg.
        *args: Forwarded to the custom registry or to
            `AutoTokenizer.from_pretrained`.
        tokenizer_mode: `"mistral"` loads a `MistralTokenizer`. `"custom"`
            goes through `TokenizerRegistry`. Any other value uses
            `AutoTokenizer`; `"slow"` additionally forces `use_fast=False`.
        trust_remote_code: Forwarded to `AutoTokenizer.from_pretrained`.
        revision: Model revision to load.
        download_dir: Cache directory for ModelScope downloads; also passed
            to custom tokenizers.
        **kwargs: Extra loader kwargs. `truncation_side` defaults to
            `"left"` when not supplied.

    Returns:
        The loaded tokenizer. Tokenizers loaded via `AutoTokenizer` are
        wrapped with `get_cached_tokenizer`.

    Raises:
        ValueError: If `use_fast=True` is requested in `"slow"` mode, or if
            `AutoTokenizer` fails for an unrelated reason.
        RuntimeError: If `AutoTokenizer` fails because the tokenizer needs
            remote code and `trust_remote_code` is False.
    """
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # Imported here to avoid a circular import.
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
                )
                tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    kwargs.setdefault("truncation_side", "left")

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            "It is strongly recommended to run mistral models with "
            '`--tokenizer-mode "mistral"` to ensure correct '
            "encoding and decoding.",
            FutureWarning,
            stacklevel=2,
        )

    tokenizer: AnyTokenizer
    if tokenizer_mode == "mistral":
        tokenizer = MistralTokenizer.from_pretrained(
            str(tokenizer_name), revision=revision
        )
    elif tokenizer_mode == "custom":
        from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

        tokenizer = TokenizerRegistry.get_tokenizer(
            str(tokenizer_name),
            *args,
            revision=revision,
            download_dir=download_dir,
            **kwargs,
        )
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                "does not exist or is not currently imported." in str(e)
                or "requires you to execute the tokenizer file" in str(e)
            ):
                err_msg = (
                    "Failed to load the tokenizer. If the tokenizer "
                    "is a custom tokenizer not yet available in the "
                    "HuggingFace transformers library, consider "
                    "setting `trust_remote_code=True` in LLM or using "
                    "the `--trust-remote-code` flag in the CLI."
                )
                raise RuntimeError(err_msg) from e
            raise

        # The special_tokens in tokenizer should also be
        # controlled by do_lower_case in encoder_config
        encoder_config = get_sentence_transformer_tokenizer_config(
            tokenizer_name, revision
        )
        if isinstance(encoder_config, dict) and encoder_config.get(
            "do_lower_case", False
        ):
            # Most special_tokens_map values are single strings, but
            # `additional_special_tokens` maps to a list of strings.
            # Lower-case list entries one by one, because `.lower()` on the
            # list itself would raise AttributeError.
            special_tokens_map = {
                k: (
                    [token.lower() for token in v]
                    if isinstance(v, (list, tuple))
                    else v.lower()
                )
                for k, v in tokenizer.special_tokens_map.items()
            }
            tokenizer.add_special_tokens(special_tokens_map)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead."
            )
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer
262
263


264
265
266
267
# Memoized `get_tokenizer`. The bound matches lru_cache's default of 128
# entries; the cache key includes every positional and keyword argument.
cached_get_tokenizer = lru_cache(maxsize=128)(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: ModelConfig,
    **kwargs: Any,
):
    """Return a cached tokenizer built from the tokenizer fields of a config.

    Reads `tokenizer`, `tokenizer_mode`, `tokenizer_revision` and
    `trust_remote_code` from `model_config` and calls `cached_get_tokenizer`.
    Repeated calls with the same values reuse the cached tokenizer.

    Args:
        model_config: Source of the tokenizer settings.
        **kwargs: Extra arguments for `get_tokenizer`. They must be hashable,
            because they become part of the cache key.
    """
    name = model_config.tokenizer
    mode = model_config.tokenizer_mode
    rev = model_config.tokenizer_revision
    remote_code = model_config.trust_remote_code
    # Keep the keyword order stable: lru_cache keys are order-sensitive.
    return cached_get_tokenizer(
        name,
        tokenizer_mode=mode,
        revision=rev,
        trust_remote_code=remote_code,
        **kwargs,
    )


280
281
282
283
284
285
286
287
def init_tokenizer_from_configs(model_config: ModelConfig):
    runner_type = model_config.runner_type
    if runner_type == "generate" or runner_type == "draft":
        truncation_side = "left"
    elif runner_type == "pooling":
        truncation_side = "right"
    else:
        assert_never(runner_type)
288

289
290
291
292
293
294
295
    return get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,
        truncation_side=truncation_side,
    )