tokenizer.py 9.81 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import copy
6
import os
7
import warnings
8
from functools import lru_cache
9
from pathlib import Path
10
from typing import TYPE_CHECKING, Any, TypeAlias
11

12
import huggingface_hub
13
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
14
from typing_extensions import assert_never
15

16
from vllm import envs
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.logger import init_logger
18
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
19
from vllm.transformers_utils.tokenizers import MistralTokenizer
20
from vllm.transformers_utils.utils import check_gguf_file
21

22
23
if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
    # At runtime these names are only needed for annotations and for the
    # `AnyTokenizer` union below, so `Any` stands in for both and the real
    # modules are not imported.
    ModelConfig = TokenizerBase = Any

logger = init_logger(__name__)

# Union of the tokenizer flavours accepted by the helpers in this module.
AnyTokenizer: TypeAlias = (
    PreTrainedTokenizer | PreTrainedTokenizerFast | TokenizerBase
)


def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool | None = None,
) -> str:
    """
    Decode `token_ids` back into text, independent of the tokenizer backend.

    Mirrors HF's `tokenizer.decode(token_ids, ...)`. When
    `skip_special_tokens` is `None` the flag is not forwarded at all, so
    the backend's own default applies.
    """
    if skip_special_tokens is None:
        # Leave the decision to the backend's default behaviour.
        return tokenizer.decode(token_ids)
    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)


def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    truncation: bool | None = None,
    max_length: int | None = None,
    add_special_tokens: bool | None = None,
) -> list[int]:
    """
    Encode `text` into token ids, independent of the tokenizer backend.

    Mirrors HF's `tokenizer.encode(text, ...)`. Each keyword option is
    forwarded only when it is not `None`; a `None` value leaves that option
    to the backend's default (e.g. `add_special_tokens=None` keeps the
    backend's own special-token setting).
    """
    requested: dict[str, Any] = {
        "max_length": max_length,
        "truncation": truncation,
        "add_special_tokens": add_special_tokens,
    }
    # Drop unset options so the backend falls back to its own defaults.
    forwarded = {key: val for key, val in requested.items() if val is not None}
    return tokenizer.encode(text, **forwarded)


def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """
    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown.
    This proxy caches these properties for faster access.

    The result is a shallow copy of `tokenizer` whose class is swapped for a
    dynamically created subclass named `Cached<OriginalClassName>`. That
    subclass serves `all_special_ids`, `all_special_tokens`,
    `all_special_tokens_extended`, `get_vocab()`, `len()` and
    `max_token_id` from values snapshotted when this function runs. The
    passed-in tokenizer object keeps its original class.

    Args:
        tokenizer: The tokenizer to wrap.

    Returns:
        The caching proxy. Pickling it pickles the original tokenizer and
        re-wraps it with `get_cached_tokenizer` on load.
    """
    cached_tokenizer = copy.copy(tokenizer)

    # Snapshot the properties once; the subclass below returns these values.
    tokenizer_all_special_ids = tokenizer.all_special_ids
    tokenizer_all_special_tokens = tokenizer.all_special_tokens
    tokenizer_all_special_tokens_extended = tokenizer.all_special_tokens_extended
    tokenizer_vocab = tokenizer.get_vocab()
    tokenizer_len = len(tokenizer)

    max_token_id = max(tokenizer_vocab.values())
    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    # The property access must happen *inside* the suppress block:
    # `hasattr()` only swallows AttributeError, so a `vocab_size` property
    # raising NotImplementedError would escape a `hasattr` guard.
    # NOTE(review): if `vocab_size` is a token count rather than an id, this
    # bound can be one above the largest real token id -- confirm intended.
    with contextlib.suppress(AttributeError, NotImplementedError):
        max_token_id = max(max_token_id, tokenizer.vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore
        @property
        def all_special_ids(self) -> list[int]:
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self) -> list[str]:
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self) -> list[str]:
            return tokenizer_all_special_tokens_extended

        @property
        def max_token_id(self) -> int:
            return max_token_id

        def get_vocab(self) -> dict[str, int]:
            return tokenizer_vocab

        def __len__(self) -> int:
            return tokenizer_len

        def __reduce__(self):
            # Rebuild the proxy from the original (uncached) tokenizer.
            return get_cached_tokenizer, (tokenizer,)

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    cached_tokenizer.__class__ = CachedTokenizer
    return cached_tokenizer


def _lowercase_special_token(value: Any) -> Any:
    """Lower-case one value taken from a tokenizer's `special_tokens_map`.

    Most entries are plain strings, but list-valued entries (e.g.
    `additional_special_tokens`) must be lower-cased element-wise; calling
    `.lower()` on the list itself raises AttributeError. Values of any other
    type are returned unchanged.
    """
    if isinstance(value, str):
        return value.lower()
    if isinstance(value, (list, tuple)):
        return [_lowercase_special_token(item) for item in value]
    return value


def get_tokenizer(
    tokenizer_name: str | Path,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: str | None = None,
    download_dir: str | None = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Repo id or local path of the tokenizer. A path to a
            GGUF file is split into its parent folder plus a `gguf_file`
            kwarg.
        *args: Forwarded positionally to the custom registry or to
            `AutoTokenizer.from_pretrained`.
        tokenizer_mode: `"mistral"` loads a `MistralTokenizer`, `"custom"`
            goes through `TokenizerRegistry`, `"slow"` forces
            `use_fast=False`; any other value uses `AutoTokenizer`.
        trust_remote_code: Forwarded to `AutoTokenizer.from_pretrained`
            (only used on the `AutoTokenizer` path).
        revision: Revision of the tokenizer to load.
        download_dir: Cache directory for ModelScope downloads; also
            forwarded to the custom tokenizer registry.
        **kwargs: Extra loader kwargs. `truncation_side` defaults to
            `"left"` when not provided.

    Returns:
        The loaded tokenizer; tokenizers loaded through `AutoTokenizer` are
        wrapped with `get_cached_tokenizer`.

    Raises:
        ValueError: If `tokenizer_mode="slow"` is combined with
            `use_fast=True`, or if `AutoTokenizer` fails for another reason.
        RuntimeError: If `AutoTokenizer` fails because the tokenizer needs
            remote code and `trust_remote_code` is False.
    """
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # avoid circular import
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
                )
                tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = "left"

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            "It is strongly recommended to run mistral models with "
            '`--tokenizer-mode "mistral"` to ensure correct '
            "encoding and decoding.",
            FutureWarning,
            stacklevel=2,
        )

    tokenizer: AnyTokenizer
    if tokenizer_mode == "mistral":
        tokenizer = MistralTokenizer.from_pretrained(
            str(tokenizer_name), revision=revision
        )
    elif tokenizer_mode == "custom":
        from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

        tokenizer = TokenizerRegistry.get_tokenizer(
            str(tokenizer_name),
            *args,
            revision=revision,
            download_dir=download_dir,
            **kwargs,
        )
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                "does not exist or is not currently imported." in str(e)
                or "requires you to execute the tokenizer file" in str(e)
            ):
                err_msg = (
                    "Failed to load the tokenizer. If the tokenizer "
                    "is a custom tokenizer not yet available in the "
                    "HuggingFace transformers library, consider "
                    "setting `trust_remote_code=True` in LLM or using "
                    "the `--trust-remote-code` flag in the CLI."
                )
                raise RuntimeError(err_msg) from e
            # Re-raise unchanged, preserving the original traceback.
            raise

        # The special_tokens in tokenizer should also be
        # controlled by do_lower_case in encoder_config
        encoder_config = get_sentence_transformer_tokenizer_config(
            tokenizer_name, revision
        )
        if isinstance(encoder_config, dict) and encoder_config.get(
            "do_lower_case", False
        ):
            # Values may be strings or lists of strings, so lower-case each
            # entry via the helper instead of calling `.lower()` directly.
            special_tokens_map = {
                k: _lowercase_special_token(v)
                for k, v in tokenizer.special_tokens_map.items()
            }
            tokenizer.add_special_tokens(special_tokens_map)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead."
            )
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer


# Memoized `get_tokenizer`: repeated calls with equal (hashable) arguments
# return the same tokenizer object instead of loading it again.
cached_get_tokenizer = lru_cache(maxsize=128)(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: ModelConfig,
    **kwargs: Any,
):
    """Return the tokenizer described by `model_config`, via the memo cache.

    The tokenizer name, mode, revision and `trust_remote_code` flag are read
    from `model_config`; any extra `kwargs` are forwarded as-is. Because the
    call goes through `lru_cache`, every forwarded argument must be hashable.
    """
    name = model_config.tokenizer
    settings = {
        "tokenizer_mode": model_config.tokenizer_mode,
        "revision": model_config.tokenizer_revision,
        "trust_remote_code": model_config.trust_remote_code,
    }
    return cached_get_tokenizer(name, **settings, **kwargs)


def init_tokenizer_from_configs(model_config: ModelConfig):
    runner_type = model_config.runner_type
    if runner_type == "generate" or runner_type == "draft":
        truncation_side = "left"
    elif runner_type == "pooling":
        truncation_side = "right"
    else:
        assert_never(runner_type)
286

287
288
289
290
291
292
293
    return get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,
        truncation_side=truncation_side,
    )