tokenizer.py 10.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import copy
6
import os
7
import warnings
8
from functools import lru_cache
9
from pathlib import Path
10
from typing import TYPE_CHECKING, Any, Optional, Union
11

12
import huggingface_hub
13
from transformers import (AutoTokenizer, PreTrainedTokenizer,
14
                          PreTrainedTokenizerFast)
15
from typing_extensions import assert_never
16

17
from vllm import envs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.logger import init_logger
19
20
from vllm.transformers_utils.config import (
    get_sentence_transformer_tokenizer_config)
21
from vllm.transformers_utils.tokenizers import MistralTokenizer
22
from vllm.transformers_utils.utils import check_gguf_file
23

24
25
if TYPE_CHECKING:
    from vllm.config import ModelConfig
26
27
28
29
30
31
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
    ModelConfig = Any
    LoRARequest = Any
    TokenizerBase = Any
32

33
34
logger = init_logger(__name__)

35
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
36
                     TokenizerBase]
37

38

39
40
41
42
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Turn ``token_ids`` back into text, independent of the tokenizer backend.

    Behaves like HF's ``tokenizer.decode(token_ids, ...)``. When
    ``skip_special_tokens`` is ``None`` the keyword is not passed at all, so
    the backend's own default setting applies.
    """
    if skip_special_tokens is None:
        # Leave the choice entirely to the backend.
        return tokenizer.decode(token_ids)

    return tokenizer.decode(
        token_ids,
        skip_special_tokens=skip_special_tokens,
    )
57
58


59
60
61
62
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    truncation: Optional[bool] = None,
    max_length: Optional[int] = None,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Tokenize ``text`` into ids, independent of the tokenizer backend.

    Behaves like HF's ``tokenizer.encode(text, ...)``. Any option left as
    ``None`` is omitted from the call, so the backend's default setting
    applies for it.
    """
    requested = {
        "max_length": max_length,
        "truncation": truncation,
        "add_special_tokens": add_special_tokens,
    }
    # Only forward options the caller explicitly set.
    explicit = {
        name: value
        for name, value in requested.items() if value is not None
    }
    return tokenizer.encode(text, **explicit)
86
87


88
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """
    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown.
    This proxy caches these properties for faster access.

    The returned object is a shallow copy of ``tokenizer``. Its class is
    swapped for a dynamically created subclass. That subclass serves the
    following from values snapshotted when this function runs:
    ``all_special_ids``, ``all_special_tokens``,
    ``all_special_tokens_extended``, ``get_vocab()``, ``len()`` and an extra
    ``max_token_id`` property. The input ``tokenizer`` object itself is not
    modified.

    Pickling the result re-runs this function on the original tokenizer
    (see ``__reduce__``).
    """
    cached_tokenizer = copy.copy(tokenizer)

    tokenizer_all_special_ids = tokenizer.all_special_ids
    tokenizer_all_special_tokens = tokenizer.all_special_tokens
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_vocab = tokenizer.get_vocab()
    tokenizer_len = len(tokenizer)

    max_token_id = max(tokenizer_vocab.values())
    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    # `hasattr` only swallows AttributeError, so probing with it would let a
    # NotImplementedError raised by an unimplemented `vocab_size` property
    # escape. Read the attribute directly and treat both "missing" and
    # "not implemented" as "no extra information".
    with contextlib.suppress(AttributeError, NotImplementedError):
        max_token_id = max(max_token_id, tokenizer.vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore

        @property
        def all_special_ids(self) -> list[int]:
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self) -> list[str]:
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self) -> list[str]:
            return tokenizer_all_special_tokens_extended

        @property
        def max_token_id(self) -> int:
            return max_token_id

        def get_vocab(self) -> dict[str, int]:
            return tokenizer_vocab

        def __len__(self) -> int:
            return tokenizer_len

        def __reduce__(self):
            # The dynamic class cannot be pickled by reference; rebuild the
            # proxy from the original tokenizer instead.
            return get_cached_tokenizer, (tokenizer, )

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    cached_tokenizer.__class__ = CachedTokenizer
    return cached_tokenizer
143
144


145
def get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Model id or local path. A GGUF file path (as detected
            by ``check_gguf_file``) is split into its parent directory plus a
            ``gguf_file`` keyword argument.
        *args: Extra positional arguments forwarded to the loader in
            ``"custom"`` mode and in the default ``AutoTokenizer`` path.
        tokenizer_mode: ``"mistral"`` loads a ``MistralTokenizer``;
            ``"custom"`` uses ``TokenizerRegistry``; ``"slow"`` forces
            ``use_fast=False``; any other value goes through
            ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        revision: Model revision forwarded to the loader.
        download_dir: Cache directory for ModelScope downloads; also
            forwarded to custom tokenizers.
        **kwargs: Extra keyword arguments forwarded to the loader.
            ``truncation_side`` defaults to ``"left"`` when not given.

    Returns:
        The loaded tokenizer. Tokenizers loaded through ``AutoTokenizer`` are
        wrapped by ``get_cached_tokenizer``.

    Raises:
        ValueError: If ``tokenizer_mode == "slow"`` while ``use_fast`` is
            truthy, or if ``AutoTokenizer`` raises one that is not about
            remote code.
        RuntimeError: If ``AutoTokenizer`` reports that the tokenizer class
            needs remote code and ``trust_remote_code`` is False.
    """
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # avoid circular import
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
                tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    kwargs.setdefault("truncation_side", "left")

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            'It is strongly recommended to run mistral models with '
            '`--tokenizer-mode "mistral"` to ensure correct '
            'encoding and decoding.',
            FutureWarning,
            stacklevel=2)

    tokenizer: AnyTokenizer
    if tokenizer_mode == "mistral":
        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                     revision=revision)
    elif tokenizer_mode == "custom":
        from vllm.transformers_utils.tokenizer_base import TokenizerRegistry
        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
                                                    *args,
                                                    revision=revision,
                                                    download_dir=download_dir,
                                                    **kwargs)
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                    "does not exist or is not currently imported." in str(e)
                    or "requires you to execute the tokenizer file" in str(e)):
                err_msg = ("Failed to load the tokenizer. If the tokenizer "
                           "is a custom tokenizer not yet available in the "
                           "HuggingFace transformers library, consider "
                           "setting `trust_remote_code=True` in LLM or using "
                           "the `--trust-remote-code` flag in the CLI.")
                raise RuntimeError(err_msg) from e
            # Bare `raise` keeps the original traceback untouched.
            raise

        # The special_tokens in tokenizer should also be
        # controlled by do_lower_case in encoder_config
        encoder_config = get_sentence_transformer_tokenizer_config(
            tokenizer_name, revision)
        if isinstance(encoder_config, dict) and encoder_config.get(
                "do_lower_case", False):
            # Per the transformers docs, `special_tokens_map` values are
            # single tokens, except `additional_special_tokens`, which is a
            # list. Lowercase list entries element-wise; calling `.lower()`
            # on the list itself would raise AttributeError.
            special_tokens_map = {
                k: ([token.lower() for token in v]
                    if isinstance(v, (list, tuple)) else v.lower())
                for k, v in tokenizer.special_tokens_map.items()
            }
            tokenizer.add_special_tokens(special_tokens_map)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead.")
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer
259
260


261
262
263
264
# Memoized variant of `get_tokenizer`: repeated calls with equal (hashable)
# arguments return the same tokenizer object. At most 128 distinct argument
# combinations are kept (the `lru_cache` default, spelled out here).
cached_get_tokenizer = lru_cache(maxsize=128)(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: ModelConfig,
    **kwargs: Any,
):
    """Return the memoized tokenizer described by ``model_config``.

    Reads ``tokenizer``, ``tokenizer_mode``, ``tokenizer_revision`` and
    ``trust_remote_code`` from the config. Extra keyword arguments are
    forwarded unchanged to ``cached_get_tokenizer``. That function is
    ``lru_cache``-wrapped, so every value in ``kwargs`` must be hashable.
    """
    name = model_config.tokenizer
    mode = model_config.tokenizer_mode
    tokenizer_revision = model_config.tokenizer_revision
    allow_remote_code = model_config.trust_remote_code
    return cached_get_tokenizer(
        name,
        tokenizer_mode=mode,
        revision=tokenizer_revision,
        trust_remote_code=allow_remote_code,
        **kwargs,
    )


277
278
279
280
281
282
283
284
def init_tokenizer_from_configs(model_config: ModelConfig):
    """Load a tokenizer for ``model_config`` via ``get_tokenizer``.

    This call is not memoized; compare ``cached_tokenizer_from_config``.

    The truncation side is chosen from ``model_config.runner_type``:
    ``"generate"`` and ``"draft"`` truncate on the left, ``"pooling"``
    truncates on the right. Any other value is passed to ``assert_never``.
    """
    runner_type = model_config.runner_type
    if runner_type in ("generate", "draft"):
        truncation_side = "left"
    elif runner_type == "pooling":
        truncation_side = "right"
    else:
        assert_never(runner_type)

    return get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,
        truncation_side=truncation_side,
    )