tokenizer.py 10.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import copy
6
import os
7
import warnings
8
from functools import lru_cache
9
from pathlib import Path
10
from types import MethodType
11
from typing import TYPE_CHECKING, Any, Optional, Union
12

13
import huggingface_hub
14
from transformers import (AutoTokenizer, PreTrainedTokenizer,
15
16
                          PreTrainedTokenizerFast)

17
from vllm import envs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.logger import init_logger
19
from vllm.transformers_utils.tokenizers import MistralTokenizer
20
from vllm.transformers_utils.utils import check_gguf_file
21
from vllm.utils import make_async
22

23
24
if TYPE_CHECKING:
    # Needed for static type checking only; presumably kept out of the
    # runtime import graph to avoid import cycles / heavy imports.
    from vllm.config import ModelConfig
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
    # At runtime these names only need to exist so annotations and the
    # `AnyTokenizer` union can reference them; alias all of them to `Any`.
    ModelConfig = LoRARequest = TokenizerBase = Any
31

32
33
logger = init_logger(__name__)

# Union of every tokenizer flavour handled in this module: a slow or fast
# HF tokenizer, or an implementation of vLLM's `TokenizerBase`.
AnyTokenizer = Union[
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TokenizerBase,
]
36

37

38
39
40
41
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Backend-agnostic equivalent of HF's `tokenizer.decode(token_ids, ...)`.

    Args:
        tokenizer: Any tokenizer exposing a `decode` method.
        token_ids: Token IDs to convert back into text.
        skip_special_tokens: Forwarded to `decode` only when it is not
            `None`; `None` leaves the backend's own default in effect.

    Returns:
        The decoded text.
    """
    # Omit the keyword entirely when unset so the backend default applies.
    if skip_special_tokens is None:
        return tokenizer.decode(token_ids)
    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
56
57


58
59
60
61
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    truncation: Optional[bool] = None,
    max_length: Optional[int] = None,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's `tokenizer.encode(text, ...)`.

    Every keyword option is forwarded to `tokenizer.encode` only when it is
    not `None`; leaving an option as `None` keeps the backend's default.

    Args:
        tokenizer: Any tokenizer exposing an `encode` method.
        text: The text to tokenize.
        truncation: Optional `truncation` flag for `encode`.
        max_length: Optional `max_length` for `encode`.
        add_special_tokens: Optional `add_special_tokens` flag for `encode`.

    Returns:
        Whatever `tokenizer.encode` returns (the token IDs).
    """
    candidate_options = (
        ("max_length", max_length),
        ("truncation", truncation),
        ("add_special_tokens", add_special_tokens),
    )
    encode_options: dict[str, Any] = {
        option: value
        for option, value in candidate_options if value is not None
    }
    return tokenizer.encode(text, **encode_options)
85
86


87
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """
    Return a copy of `tokenizer` whose expensive properties are cached.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown.
    This proxy snapshots those values once, when it is created, and serves
    the snapshot afterwards.

    The result is a shallow copy (`copy.copy`) of `tokenizer` whose class is
    swapped for a dynamically created subclass named
    `Cached<OriginalClassName>`. The input tokenizer object itself is not
    modified.

    Cached members: `all_special_ids`, `all_special_tokens`,
    `all_special_tokens_extended`, `get_vocab()` and `len()`. A new
    `max_token_id` property is also added.

    Args:
        tokenizer: The tokenizer to wrap.

    Returns:
        The caching proxy tokenizer.
    """
    cached_tokenizer = copy.copy(tokenizer)

    # Snapshot the expensive-to-compute values once.
    tokenizer_all_special_ids = tokenizer.all_special_ids
    tokenizer_all_special_tokens = tokenizer.all_special_tokens
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_vocab = tokenizer.get_vocab()
    tokenizer_len = len(tokenizer)

    max_token_id = max(tokenizer_vocab.values())
    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    # `hasattr` would evaluate the property and only swallow AttributeError,
    # letting a NotImplementedError escape. Instead, read the property once
    # and treat "missing" and "not implemented" alike.
    # NOTE(review): `vocab_size` is usually a count, not an ID, so this may
    # overestimate the largest ID by one — kept as-is, confirm intent.
    with contextlib.suppress(AttributeError, NotImplementedError):
        max_token_id = max(max_token_id, tokenizer.vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore

        @property
        def all_special_ids(self) -> list[int]:
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self) -> list[str]:
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self) -> list[str]:
            return tokenizer_all_special_tokens_extended

        @property
        def max_token_id(self) -> int:
            return max_token_id

        def get_vocab(self) -> dict[str, int]:
            # Returns the shared snapshot dict itself (no copy per call).
            return tokenizer_vocab

        def __len__(self) -> int:
            return tokenizer_len

        def __reduce__(self):
            # This class is created inside a function and cannot be pickled
            # by reference, so pickle as "re-wrap the original tokenizer".
            return get_cached_tokenizer, (tokenizer, )

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    cached_tokenizer.__class__ = CachedTokenizer
    return cached_tokenizer
142
143


144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
    """
    Make `tokenizer._pad` tolerate a `padding_side` keyword.

    Older tokenizers implement `_pad` without a `padding_side` parameter.
    This replaces the instance's `_pad` with a wrapper that strips that
    keyword and then delegates to the original bound `_pad`. When the
    requested side differs from `tokenizer.padding_side`, it emits a warning
    saying the argument is ignored.

    Args:
        tokenizer: The tokenizer instance to patch in place.
    """
    original_pad = tokenizer._pad

    def _pad(
        self: PreTrainedTokenizer,
        *args,
        padding_side: Optional[str] = None,
        **kwargs,
    ):
        wants_other_side = (padding_side is not None
                            and padding_side != self.padding_side)
        if wants_other_side:
            warnings.warn(
                "`padding_side` argument is not supported by "
                f"{type(tokenizer).__name__} and will be ignored.",
                stacklevel=2)

        # `padding_side` has been consumed above and is not forwarded.
        return original_pad(*args, **kwargs)

    tokenizer._pad = MethodType(_pad, tokenizer)


164
def get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Hub model ID or local path. If it points to a GGUF
            file, the parent directory is used as the name and the file name
            is passed on as `gguf_file`.
        *args: Extra positional arguments for the HF / custom-registry
            loaders. The "mistral" path does not receive them.
        tokenizer_mode: "slow" forces `use_fast=False` and then loads via
            `AutoTokenizer`. "mistral" loads a `MistralTokenizer`. "custom"
            resolves through `TokenizerRegistry`. Any other value
            (e.g. "auto") uses `AutoTokenizer`.
        trust_remote_code: Forwarded to `AutoTokenizer.from_pretrained`.
        revision: Revision to load. It is also used for ModelScope downloads.
        download_dir: Cache directory for ModelScope downloads. It is also
            passed to the custom registry.
        **kwargs: Forwarded to the HF / custom-registry loaders.
            `truncation_side` defaults to "left" when not given.

    Returns:
        The loaded tokenizer. Tokenizers loaded via `AutoTokenizer` are
        wrapped by `get_cached_tokenizer`.

    Raises:
        ValueError: If `tokenizer_mode="slow"` is combined with
            `use_fast=True`.
        RuntimeError: If `AutoTokenizer` fails in a way that suggests
            `trust_remote_code` is required but it was not enabled.
    """
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # avoid circular import
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
                tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = "left"

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            'It is strongly recommended to run mistral models with '
            '`--tokenizer-mode "mistral"` to ensure correct '
            'encoding and decoding.',
            FutureWarning,
            stacklevel=2)

    tokenizer: AnyTokenizer
    if tokenizer_mode == "mistral":
        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                     revision=revision)
    elif tokenizer_mode == "custom":
        from vllm.transformers_utils.tokenizer_base import TokenizerRegistry
        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
                                                    *args,
                                                    revision=revision,
                                                    download_dir=download_dir,
                                                    **kwargs)
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                    "does not exist or is not currently imported." in str(e)
                    or "requires you to execute the tokenizer file" in str(e)):
                err_msg = ("Failed to load the tokenizer. If the tokenizer "
                           "is a custom tokenizer not yet available in the "
                           "HuggingFace transformers library, consider "
                           "setting `trust_remote_code=True` in LLM or using "
                           "the `--trust-remote-code` flag in the CLI.")
                raise RuntimeError(err_msg) from e
            # Bare `raise` re-raises with the original traceback intact.
            raise

        # NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
        if type(tokenizer).__name__ in ("ChatGLMTokenizer",
                                        "ChatGLM4Tokenizer"):
            assert isinstance(tokenizer, PreTrainedTokenizer)
            patch_padding_side(tokenizer)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead.")
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer
272
273


274
275
276
277
# Memoized `get_tokenizer` (the default LRU size of 128, spelled out).
# All arguments must be hashable to serve as cache keys.
cached_get_tokenizer = lru_cache(maxsize=128)(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: ModelConfig,
    **kwargs: Any,
):
    """Return the memoized tokenizer described by `model_config`.

    Delegates to `cached_get_tokenizer` using the tokenizer name, mode,
    revision and `trust_remote_code` setting from `model_config`. Any extra
    `kwargs` are forwarded unchanged. Because the call goes through
    `lru_cache`, all forwarded values must be hashable.

    `model_config.tokenizer_revision` is passed as `revision`, the parameter
    `get_tokenizer` actually declares. An explicit `revision` in `kwargs`
    takes precedence.
    """
    # `get_tokenizer` has no `tokenizer_revision` parameter. Passing it under
    # that name would drop it into **kwargs (forwarded to the HF loader)
    # and leave `revision=None`, silently ignoring the configured revision.
    kwargs.setdefault("revision", model_config.tokenizer_revision)
    return cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        **kwargs,
    )


290
def get_lora_tokenizer(lora_request: Optional[LoRARequest], *args,
                       **kwargs) -> Optional[AnyTokenizer]:
    """Load the tokenizer stored alongside a LoRA adapter, if any.

    Args:
        lora_request: The LoRA request whose `lora_path` is passed to
            `get_tokenizer`. `None` short-circuits to `None`.
        *args: Forwarded to `get_tokenizer`.
        **kwargs: Forwarded to `get_tokenizer`.

    Returns:
        The adapter's tokenizer. Returns `None` when there is no request or
        when loading fails; the log message indicates the base model
        tokenizer is then used instead.
    """
    if lora_request is None:
        return None
    try:
        return get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except Exception as e:
        # Deliberately broad best-effort: if no tokenizer can be loaded from
        # the LoRA folder, log it and fall back to the base model tokenizer.
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_path, e)
        return None


# Async counterpart of `get_lora_tokenizer`, produced by vllm's `make_async`
# helper (presumably runs the blocking load in an executor — see make_async).
get_lora_tokenizer_async = make_async(get_lora_tokenizer)