tokenizer.py 10.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import contextlib
4
import os
5
import warnings
6
from functools import lru_cache
7
from pathlib import Path
8
from types import MethodType
9
from typing import TYPE_CHECKING, Any, Optional, Union
10

11
import huggingface_hub
12
from transformers import (AutoTokenizer, PreTrainedTokenizer,
13
14
                          PreTrainedTokenizerFast)

15
from vllm.envs import VLLM_USE_MODELSCOPE
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
19
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
                                                    TokenizerRegistry)
20
from vllm.transformers_utils.tokenizers import MistralTokenizer
21
from vllm.transformers_utils.utils import check_gguf_file
22
from vllm.utils import make_async
23

24
25
26
if TYPE_CHECKING:
    # Only needed for type annotations; guarded so it is not imported at
    # runtime.
    from vllm.config import ModelConfig

logger = init_logger(__name__)

# Every tokenizer type this module can hand back: either Hugging Face
# tokenizer flavour, or an implementation of vLLM's ``TokenizerBase``.
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
                     TokenizerBase]
31

32

33
34
35
36
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.decode(token_ids, ...)`.

    :code:`skip_special_tokens=None` means to use the backend's default
    settings.

    Args:
        tokenizer: Any supported tokenizer backend.
        token_ids: The token ids to turn back into text.
        skip_special_tokens: When not ``None``, forwarded to
            ``tokenizer.decode``; when ``None`` the keyword is omitted.

    Returns:
        The string returned by ``tokenizer.decode``.
    """
    if skip_special_tokens is not None:
        return tokenizer.decode(token_ids,
                                skip_special_tokens=skip_special_tokens)

    # Omit the keyword entirely so the backend's own default applies.
    return tokenizer.decode(token_ids)
51
52


53
54
55
56
57
58
59
60
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, ...)`.

    :code:`add_special_tokens=None` means to use the backend's default
    settings.

    Args:
        tokenizer: Any supported tokenizer backend.
        text: The string to encode.
        add_special_tokens: When not ``None``, forwarded to
            ``tokenizer.encode``; when ``None`` the keyword is omitted.

    Returns:
        The token ids returned by ``tokenizer.encode``.
    """
    if add_special_tokens is not None:
        return tokenizer.encode(text, add_special_tokens=add_special_tokens)

    # Omit the keyword entirely so the backend's own default applies.
    return tokenizer.encode(text)


72
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """Get tokenizer with cached properties.

    This will patch the tokenizer object in place.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown. This
    function caches these properties for faster access.

    The special ids/tokens, vocab, length and a derived ``max_token_id`` are
    snapshotted once here, so later changes to the underlying tokenizer are
    not reflected by the cached properties.

    Args:
        tokenizer: The tokenizer to patch. Its ``__class__`` is replaced by a
            dynamically created subclass.

    Returns:
        The same ``tokenizer`` object, now an instance of
        ``Cached<OriginalClassName>``.
    """
    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
    tokenizer_vocab = tokenizer.get_vocab()
    tokenizer_len = len(tokenizer)

    max_token_id = max(tokenizer_vocab.values())
    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    # NOTE: ``hasattr`` only swallows AttributeError, so a ``vocab_size``
    # property that raises NotImplementedError would escape from
    # ``hasattr`` itself. Suppress both exceptions around the access instead.
    with contextlib.suppress(AttributeError, NotImplementedError):
        max_token_id = max(max_token_id, tokenizer.vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore

        @property
        def all_special_ids(self):
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self):
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return tokenizer_all_special_tokens_extended

        @property
        def max_token_id(self):
            return max_token_id

        def get_vocab(self):
            return tokenizer_vocab

        def __len__(self):
            return tokenizer_len

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    tokenizer.__class__ = CachedTokenizer
    return tokenizer


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
    """Make ``tokenizer._pad`` tolerate a ``padding_side`` keyword.

    Intended for older tokenizers whose ``_pad`` does not accept
    ``padding_side``. The instance's bound ``_pad`` is replaced by a wrapper
    that strips the keyword. It warns when the requested side differs from
    ``self.padding_side``, then delegates to the original ``_pad``.
    """
    wrapped_pad = tokenizer._pad

    def _pad(
        self: PreTrainedTokenizer,
        *args,
        padding_side: Optional[str] = None,
        **kwargs,
    ):
        wants_other_side = (padding_side is not None
                            and padding_side != self.padding_side)
        if wants_other_side:
            # The class name is read at call time, so it reflects whatever
            # class the instance has when the warning fires.
            warnings.warn(
                "`padding_side` argument is not supported by "
                f"{type(tokenizer).__name__} and will be ignored.",
                stacklevel=2)
        return wrapped_pad(*args, **kwargs)

    tokenizer._pad = MethodType(_pad, tokenizer)


147
def get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Hub repo id or local path. A path to a GGUF file is
            also accepted. In that case the file name is passed on as
            ``gguf_file`` and the tokenizer is loaded from the file's parent
            directory.
        *args: Extra positional arguments forwarded to
            ``TokenizerRegistry.get_tokenizer`` (``"custom"`` mode) or
            ``AutoTokenizer.from_pretrained`` (all non-mistral, non-custom
            modes).
        tokenizer_mode: The following values are handled specially:
            ``"mistral"`` uses ``MistralTokenizer``.
            ``"custom"`` uses ``TokenizerRegistry``.
            ``"slow"`` forces ``use_fast=False``.
            Any other value (default ``"auto"``) loads via ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        revision: Revision used for the ModelScope download and passed to
            every loader.
        download_dir: Cache directory (and lock location) for ModelScope
            downloads; also forwarded in ``"custom"`` mode.
        **kwargs: Extra keyword arguments for the underlying loader.
            ``truncation_side`` defaults to ``"left"`` when not supplied.

    Returns:
        The loaded tokenizer. Tokenizers loaded through ``AutoTokenizer``
        are additionally wrapped by :func:`get_cached_tokenizer`.

    Raises:
        ValueError: In ``"slow"`` mode when ``use_fast=True`` was requested.
            Also re-raised from ``AutoTokenizer`` when the remote-code hint
            below does not apply.
        RuntimeError: When ``AutoTokenizer`` reports a missing tokenizer
            class or one that needs remote code, and ``trust_remote_code``
            is False.
    """
    if VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # avoid circular import
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
                tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = "left"

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            'It is strongly recommended to run mistral models with '
            '`--tokenizer-mode "mistral"` to ensure correct '
            'encoding and decoding.',
            FutureWarning,
            stacklevel=2)

    tokenizer: AnyTokenizer
    if tokenizer_mode == "mistral":
        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                     revision=revision)
    elif tokenizer_mode == "custom":
        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
                                                    *args,
                                                    revision=revision,
                                                    download_dir=download_dir,
                                                    **kwargs)
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                    "does not exist or is not currently imported." in str(e)
                    or "requires you to execute the tokenizer file" in str(e)):
                err_msg = ("Failed to load the tokenizer. If the tokenizer "
                           "is a custom tokenizer not yet available in the "
                           "HuggingFace transformers library, consider "
                           "setting `trust_remote_code=True` in LLM or using "
                           "the `--trust-remote-code` flag in the CLI.")
                raise RuntimeError(err_msg) from e
            # Bare ``raise`` re-raises the original exception unchanged.
            raise

        # NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
        if type(tokenizer).__name__ in ("ChatGLMTokenizer",
                                        "ChatGLM4Tokenizer"):
            assert isinstance(tokenizer, PreTrainedTokenizer)
            patch_padding_side(tokenizer)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead.")
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer
254
255


256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
# Memoised variant of ``get_tokenizer``. Calls with the same hashable
# arguments (up to 128 distinct argument patterns, LRU-evicted) reuse the
# tokenizer built by the first such call.
cached_get_tokenizer = lru_cache(maxsize=128)(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Return the memoised tokenizer described by ``model_config``.

    The tokenizer name, mode, revision and ``trust_remote_code`` come from
    ``model_config``. Extra keyword arguments are forwarded to
    ``cached_get_tokenizer``, and an explicit ``revision`` there takes
    precedence over ``model_config.tokenizer_revision``.
    """
    # ``get_tokenizer`` calls this parameter ``revision``. Passing it as
    # ``tokenizer_revision`` would drop it into ``**kwargs``, and the
    # configured revision would never reach the download or the loaders.
    kwargs.setdefault("revision", model_config.tokenizer_revision)
    return cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        **kwargs,
    )


272
def get_lora_tokenizer(lora_request: Optional[LoRARequest], *args,
                       **kwargs) -> Optional[AnyTokenizer]:
    """Load the tokenizer stored alongside a LoRA adapter, if any.

    Args:
        lora_request: The LoRA request whose ``lora_path`` is passed to
            :func:`get_tokenizer`. ``None`` short-circuits to ``None``.
        *args: Forwarded to :func:`get_tokenizer`.
        **kwargs: Forwarded to :func:`get_tokenizer`.

    Returns:
        The adapter's tokenizer. Returns ``None`` when there is no request,
        or when loading raises. In that case a warning is logged saying the
        base model tokenizer is used instead.
    """
    if lora_request is None:
        return None
    try:
        tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except Exception as e:
        # Deliberately broad: any failure to load from the LoRA folder is
        # treated as "no tokenizer there" and the base model tokenizer is
        # used instead.
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_path, e)
        tokenizer = None
    return tokenizer


# Async variant (per its name) of ``get_lora_tokenizer``, built with vLLM's
# ``make_async`` helper. NOTE(review): presumably this runs the blocking
# tokenizer load off the event loop; confirm against
# ``vllm.utils.make_async``.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)