# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import copy
import importlib.util
import os
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias

import huggingface_hub
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import assert_never

from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
    get_sentence_transformer_tokenizer_config,
    list_filtered_repo_files,
)
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import (
    check_gguf_file,
    is_gguf,
    is_remote_gguf,
    split_remote_gguf,
)
29

30
31
if TYPE_CHECKING:
    # Only needed by static type checkers; never imported at runtime.
    from vllm.config import ModelConfig
    from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
    # Runtime placeholders so that annotations and the alias below, which
    # reference these names, can still be evaluated.
    ModelConfig = Any
    TokenizerBase = Any

logger = init_logger(__name__)

# Any tokenizer object accepted by the helpers in this module.
AnyTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast | TokenizerBase
40

41

42
43
44
45
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
46
    skip_special_tokens: bool | None = None,
47
48
49
) -> str:
    """
    Backend-agnostic equivalent of HF's
50
    `tokenizer.decode(token_ids, ...)`.
51

52
    `skip_special_tokens=None` means to use the backend's default
53
    settings.
54
    """
55
    if skip_special_tokens is not None:
56
        return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
57

58
    return tokenizer.decode(token_ids)
59
60


61
62
63
64
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
65
66
67
    truncation: bool | None = None,
    max_length: int | None = None,
    add_special_tokens: bool | None = None,
68
69
70
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
71
    `tokenizer.encode(text, ...)`.
72

73
    `add_special_tokens=None` means to use the backend's default
74
    settings.
75
    """
76
77
78
79
80
81
82
83

    kw_args: dict[str, Any] = {}
    if max_length is not None:
        kw_args["max_length"] = max_length

    if truncation is not None:
        kw_args["truncation"] = truncation

84
    if add_special_tokens is not None:
85
        kw_args["add_special_tokens"] = add_special_tokens
86

87
    return tokenizer.encode(text, **kw_args)
88
89


90
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """
    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown.
    This proxy caches these properties for faster access.

    The special-token lists, the vocabulary, the length and the maximum
    token ID are read from `tokenizer` once, when this function is called.
    The returned object is a shallow copy of `tokenizer` whose class is
    replaced by a dynamically created subclass that serves those
    snapshotted values.

    Args:
        tokenizer: The tokenizer to wrap. It is not modified.

    Returns:
        A shallow copy of `tokenizer` of type `Cached<OriginalClassName>`.
    """
    cached_tokenizer = copy.copy(tokenizer)

    tokenizer_all_special_ids = tokenizer.all_special_ids
    tokenizer_all_special_tokens = tokenizer.all_special_tokens
    tokenizer_all_special_tokens_extended = tokenizer.all_special_tokens_extended
    tokenizer_vocab = tokenizer.get_vocab()
    tokenizer_len = len(tokenizer)

    max_token_id = max(tokenizer_vocab.values())
    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    #
    # The attribute is read directly instead of being probed with
    # `hasattr` first: `hasattr` only swallows AttributeError, so a
    # `vocab_size` property that raises NotImplementedError would
    # propagate out of the probe itself.
    with contextlib.suppress(AttributeError, NotImplementedError):
        max_token_id = max(max_token_id, tokenizer.vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore
        @property
        def all_special_ids(self) -> list[int]:
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self) -> list[str]:
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self) -> list[str]:
            return tokenizer_all_special_tokens_extended

        @property
        def max_token_id(self) -> int:
            return max_token_id

        def get_vocab(self) -> dict[str, int]:
            return tokenizer_vocab

        def __len__(self) -> int:
            return tokenizer_len

        def __reduce__(self):
            # The subclass is created locally, so serialize as "rebuild the
            # cached wrapper from the original tokenizer" instead.
            return get_cached_tokenizer, (tokenizer,)

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    # Swap the class of the copy so that it keeps its instance state but
    # resolves the overridden members above.
    cached_tokenizer.__class__ = CachedTokenizer
    return cached_tokenizer
143
144


145
def get_tokenizer(
    tokenizer_name: str | Path,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: str | None = None,
    download_dir: str | None = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Local path or repository identifier of the
            tokenizer. GGUF file paths / remote GGUF references are split
            into a model location before loading.
        *args: Extra positional arguments forwarded to the custom tokenizer
            registry or to `AutoTokenizer.from_pretrained`.
        tokenizer_mode: `"slow"` forces `use_fast=False`; `"mistral"` loads
            a `MistralTokenizer`; `"custom"` loads through
            `TokenizerRegistry`; `"auto"` switches to `"mistral"` when
            `mistral_common` is installed and Mistral-format tokenizer
            files are found, and otherwise uses `AutoTokenizer`.
        trust_remote_code: Forwarded to `AutoTokenizer.from_pretrained`.
        revision: Revision to load the tokenizer from.
        download_dir: Cache directory for the ModelScope download; also
            forwarded to the custom tokenizer registry.
        **kwargs: Extra keyword arguments forwarded like `*args`.
            `truncation_side` defaults to `"left"` when not supplied.

    Returns:
        The loaded tokenizer. Tokenizers loaded through `AutoTokenizer`
        are wrapped with `get_cached_tokenizer`; Mistral and custom
        tokenizers are returned as loaded.

    Raises:
        ValueError: If `tokenizer_mode="slow"` is combined with a truthy
            `use_fast` keyword, or re-raised from `AutoTokenizer`.
        RuntimeError: If `AutoTokenizer` fails in a way that suggests
            `trust_remote_code=True` is required while it is disabled.
    """
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # avoid circular import
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
                )
                tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = "left"

    # Separate model folder from file path for GGUF models
    if is_gguf(tokenizer_name):
        if check_gguf_file(tokenizer_name):
            kwargs["gguf_file"] = Path(tokenizer_name).name
            tokenizer_name = Path(tokenizer_name).parent
        elif is_remote_gguf(tokenizer_name):
            tokenizer_name, _ = split_remote_gguf(tokenizer_name)

    # if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
    # first to use official Mistral tokenizer if possible.
    mistral_common_installed = importlib.util.find_spec("mistral_common") is not None
    if tokenizer_mode == "auto" and mistral_common_installed:
        allow_patterns = ["tekken.json", "tokenizer.model.v*"]
        files_list = list_filtered_repo_files(
            model_name_or_path=str(tokenizer_name),
            allow_patterns=allow_patterns,
            revision=revision,
        )
        if len(files_list) > 0:
            tokenizer_mode = "mistral"

    tokenizer: AnyTokenizer
    if tokenizer_mode == "mistral":
        logger.debug_once(f"Loading MistralTokenizer from {tokenizer_name}")
        tokenizer = MistralTokenizer.from_pretrained(
            str(tokenizer_name), revision=revision
        )
    elif tokenizer_mode == "custom":
        from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

        logger.debug_once(f"Loading CustomTokenizer from {tokenizer_name}")
        tokenizer = TokenizerRegistry.get_tokenizer(
            str(tokenizer_name),
            *args,
            revision=revision,
            download_dir=download_dir,
            **kwargs,
        )
    else:
        try:
            logger.debug_once(f"Loading AutoTokenizer from {tokenizer_name}")
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                "does not exist or is not currently imported." in str(e)
                or "requires you to execute the tokenizer file" in str(e)
            ):
                err_msg = (
                    "Failed to load the tokenizer. If the tokenizer "
                    "is a custom tokenizer not yet available in the "
                    "HuggingFace transformers library, consider "
                    "setting `trust_remote_code=True` in LLM or using "
                    "the `--trust-remote-code` flag in the CLI."
                )
                raise RuntimeError(err_msg) from e
            else:
                # Unrelated failure: propagate the original exception as is.
                raise

        # The special_tokens in tokenizer should also be
        # controlled by do_lower_case in encoder_config
        encoder_config = get_sentence_transformer_tokenizer_config(
            tokenizer_name, revision
        )
        if isinstance(encoder_config, dict) and encoder_config.get(
            "do_lower_case", False
        ):
            special_tokens_map = {
                k: v.lower() for k, v in tokenizer.special_tokens_map.items()
            }
            tokenizer.add_special_tokens(special_tokens_map)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead."
            )
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer
275
276


277
278
279
280
# Memoized variant of `get_tokenizer`. Because the arguments form the cache
# key, every positional and keyword argument passed to it must be hashable.
cached_get_tokenizer = lru_cache(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: ModelConfig,
    **kwargs: Any,
) -> AnyTokenizer:
    """Return the (memoized) tokenizer described by `model_config`.

    Args:
        model_config: Supplies `tokenizer`, `tokenizer_mode`,
            `tokenizer_revision` and `trust_remote_code`.
        **kwargs: Extra keyword arguments forwarded to
            `cached_get_tokenizer`; they take part in the cache key.

    Returns:
        The result of `cached_get_tokenizer` for these settings.
    """
    return cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        revision=model_config.tokenizer_revision,
        trust_remote_code=model_config.trust_remote_code,
        **kwargs,
    )


293
294
295
296
297
298
299
300
def init_tokenizer_from_configs(model_config: ModelConfig) -> AnyTokenizer:
    """Load the tokenizer for `model_config`, choosing the truncation side
    from its runner type.

    `"generate"` and `"draft"` runners truncate on the left, `"pooling"`
    runners on the right. Any other runner type reaches `assert_never`.

    Args:
        model_config: Supplies `runner_type`, `tokenizer`,
            `tokenizer_mode`, `trust_remote_code` and `tokenizer_revision`.

    Returns:
        The tokenizer returned by `get_tokenizer` (not memoized).
    """
    runner_type = model_config.runner_type
    if runner_type == "generate" or runner_type == "draft":
        truncation_side = "left"
    elif runner_type == "pooling":
        truncation_side = "right"
    else:
        assert_never(runner_type)

    return get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,
        truncation_side=truncation_side,
    )