tokenizer.py 10.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import copy
6
import os
7
import warnings
8
from functools import lru_cache
9
from pathlib import Path
10
from types import MethodType
11
from typing import TYPE_CHECKING, Any, Optional, Union
12

13
import huggingface_hub
14
from transformers import (AutoTokenizer, PreTrainedTokenizer,
15
16
                          PreTrainedTokenizerFast)

17
from vllm import envs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.logger import init_logger
19
from vllm.lora.request import LoRARequest
20
21
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
                                                    TokenizerRegistry)
22
from vllm.transformers_utils.tokenizers import MistralTokenizer
23
from vllm.transformers_utils.utils import check_gguf_file
24
from vllm.utils import make_async
25

26
27
28
if TYPE_CHECKING:
    from vllm.config import ModelConfig

29
30
logger = init_logger(__name__)

# Every tokenizer flavour the helpers in this module accept: HF slow
# tokenizers, HF fast tokenizers, and implementations of vLLM's own
# ``TokenizerBase`` interface.
AnyTokenizer = Union[
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TokenizerBase,
]
33

34

35
36
37
38
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """Decode ``token_ids`` back into text with any supported tokenizer.

    Mirrors HF's ``tokenizer.decode(token_ids, ...)``. The
    ``skip_special_tokens`` option is forwarded only when the caller set it
    explicitly; ``None`` leaves the choice to the backend's own default.
    """
    extra: dict[str, bool] = {}
    if skip_special_tokens is not None:
        extra["skip_special_tokens"] = skip_special_tokens
    return tokenizer.decode(token_ids, **extra)
53
54


55
56
57
58
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    truncation: Optional[bool] = None,
    max_length: Optional[int] = None,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """Encode ``text`` into token ids with any supported tokenizer.

    Mirrors HF's ``tokenizer.encode(text, ...)``. Each keyword option is
    forwarded only when it is not ``None``; ``None`` means "use the
    backend's default setting" for that option.
    """
    requested = {
        "max_length": max_length,
        "truncation": truncation,
        "add_special_tokens": add_special_tokens,
    }
    # Only options the caller actually set are passed on, so the backend
    # applies its own defaults for the rest.
    kw_args: dict[str, Any] = {
        name: value
        for name, value in requested.items() if value is not None
    }
    return tokenizer.encode(text, **kw_args)
82
83


84
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """Return a shallow copy of ``tokenizer`` with hot properties cached.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown. This
    snapshots the special ids/tokens, the vocab, ``len()`` and the maximum
    token id once, and serves them from a dynamically created subclass.
    Only the copy's class is swapped; the class of the passed-in object is
    not modified.

    Note: the values are captured at call time, so later mutations of the
    returned tokenizer (e.g. adding tokens) are not reflected in them.
    """
    cached_tokenizer = copy.copy(tokenizer)

    tokenizer_all_special_ids = tokenizer.all_special_ids
    tokenizer_all_special_tokens = tokenizer.all_special_tokens
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_vocab = tokenizer.get_vocab()
    tokenizer_len = len(tokenizer)

    max_token_id = max(tokenizer_vocab.values())
    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    #
    # ``hasattr`` evaluates the property and only swallows AttributeError,
    # so a ``vocab_size`` that raises NotImplementedError has to be read
    # inside the ``suppress`` block rather than probed beforehand.
    with contextlib.suppress(NotImplementedError):
        vocab_size = getattr(tokenizer, "vocab_size", None)
        if vocab_size is not None:
            max_token_id = max(max_token_id, vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore

        @property
        def all_special_ids(self) -> list[int]:
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self) -> list[str]:
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self) -> list[str]:
            return tokenizer_all_special_tokens_extended

        @property
        def max_token_id(self) -> int:
            return max_token_id

        def get_vocab(self) -> dict[str, int]:
            return tokenizer_vocab

        def __len__(self) -> int:
            return tokenizer_len

        def __reduce__(self):
            # This class is created inside a function, so pickle cannot
            # locate it by name; rebuild from the original tokenizer instead.
            return get_cached_tokenizer, (tokenizer, )

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    cached_tokenizer.__class__ = CachedTokenizer
    return cached_tokenizer
139
140


141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
    """Make ``tokenizer._pad`` tolerate a ``padding_side`` keyword.

    Intended for older tokenizers whose ``_pad`` does not accept
    ``padding_side``. The tokenizer is patched in place. The wrapper never
    forwards ``padding_side`` to the original bound ``_pad``. If the caller
    asks for a side different from ``self.padding_side``, it warns that the
    argument is ignored.
    """
    original_pad = tokenizer._pad

    def _pad(
        self: PreTrainedTokenizer,
        *args,
        padding_side: Optional[str] = None,
        **kwargs,
    ):
        wants_other_side = (padding_side is not None
                            and padding_side != self.padding_side)
        if wants_other_side:
            warnings.warn(
                "`padding_side` argument is not supported by "
                f"{type(tokenizer).__name__} and will be ignored.",
                stacklevel=2,
            )
        # ``padding_side`` is deliberately dropped here.
        return original_pad(*args, **kwargs)

    tokenizer._pad = MethodType(_pad, tokenizer)


161
def get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Repo id or local path of the tokenizer. If
            ``check_gguf_file`` recognizes it as a GGUF file, its parent
            directory is loaded and the file name is passed as
            ``gguf_file``. When ``VLLM_USE_MODELSCOPE`` is set and the name
            is not an existing local path, the repo is first downloaded
            from ModelScope (weight files excluded).
        *args: Extra positional arguments forwarded to the loader
            (``AutoTokenizer.from_pretrained`` or the ``TokenizerRegistry``
            entry for ``"custom"``).
        tokenizer_mode: ``"mistral"`` uses ``MistralTokenizer``;
            ``"custom"`` resolves through ``TokenizerRegistry``; any other
            value loads via ``AutoTokenizer``, with ``"slow"`` additionally
            forcing ``use_fast=False``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        revision: Revision forwarded to every loader and to the ModelScope
            download.
        download_dir: Cache directory for the ModelScope download; also
            forwarded to ``"custom"`` tokenizers.
        **kwargs: Forwarded to the loader. ``truncation_side`` defaults to
            ``"left"`` when not supplied.

    Returns:
        The loaded tokenizer. Tokenizers loaded via ``AutoTokenizer`` are
        wrapped with ``get_cached_tokenizer``.

    Raises:
        ValueError: If ``use_fast=True`` is requested in ``"slow"`` mode, or
            re-raised unchanged from ``AutoTokenizer.from_pretrained``.
        RuntimeError: If HF loading fails in a way that suggests the
            tokenizer needs ``trust_remote_code=True``.
    """
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # Imported lazily to avoid a circular import.
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
                tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = "left"

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            'It is strongly recommended to run mistral models with '
            '`--tokenizer-mode "mistral"` to ensure correct '
            'encoding and decoding.',
            FutureWarning,
            stacklevel=2)

    tokenizer: AnyTokenizer
    if tokenizer_mode == "mistral":
        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                     revision=revision)
    elif tokenizer_mode == "custom":
        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
                                                    *args,
                                                    revision=revision,
                                                    download_dir=download_dir,
                                                    **kwargs)
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                    "does not exist or is not currently imported." in str(e)
                    or "requires you to execute the tokenizer file" in str(e)):
                err_msg = ("Failed to load the tokenizer. If the tokenizer "
                           "is a custom tokenizer not yet available in the "
                           "HuggingFace transformers library, consider "
                           "setting `trust_remote_code=True` in LLM or using "
                           "the `--trust-remote-code` flag in the CLI.")
                raise RuntimeError(err_msg) from e
            # Bare ``raise`` re-raises the active exception with its original
            # traceback (``raise e`` would also append this frame).
            raise

        # NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
        if type(tokenizer).__name__ in ("ChatGLMTokenizer",
                                        "ChatGLM4Tokenizer"):
            assert isinstance(tokenizer, PreTrainedTokenizer)
            patch_padding_side(tokenizer)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead.")
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer
268
269


270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# Memoized ``get_tokenizer`` using the standard ``functools.lru_cache``
# defaults (``maxsize=128``, ``typed=False``): identical, hashable arguments
# return the same tokenizer object instead of reloading it.
cached_get_tokenizer = lru_cache(maxsize=128, typed=False)(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Load (or reuse) the tokenizer described by ``model_config``.

    Thin wrapper over ``cached_get_tokenizer``. It takes the tokenizer name,
    mode, revision and ``trust_remote_code`` flag from the config and
    forwards ``kwargs`` unchanged. Because the call goes through
    ``lru_cache``, every value in ``kwargs`` must be hashable.
    """
    return cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        # ``get_tokenizer`` names this parameter ``revision``. Passing it as
        # ``tokenizer_revision`` would land in ``**kwargs``, and the
        # configured revision would never reach ``revision``.
        revision=model_config.tokenizer_revision,
        trust_remote_code=model_config.trust_remote_code,
        **kwargs,
    )


286
def get_lora_tokenizer(lora_request: Optional[LoRARequest], *args,
                       **kwargs) -> Optional[AnyTokenizer]:
    """Load the tokenizer stored alongside a LoRA adapter, if any.

    Args:
        lora_request: Request whose ``lora_path`` is searched for a
            tokenizer. ``None`` is accepted and yields ``None``.
        *args, **kwargs: Forwarded to ``get_tokenizer``.

    Returns:
        The adapter's tokenizer. Returns ``None`` when there is no request
        or loading from ``lora_path`` fails for any reason; the caller then
        falls back to the base model's tokenizer.
    """
    if lora_request is None:
        return None
    try:
        return get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except Exception as e:
        # Best effort: most likely no tokenizer was found in the LoRA folder.
        # Any failure is logged and treated as "use the base tokenizer".
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_path, e)
        return None


# ``get_lora_tokenizer`` wrapped by vLLM's ``make_async`` helper.
# NOTE(review): presumably this yields an awaitable variant that runs the
# blocking tokenizer load off the event loop — confirm against make_async.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)