tokenizer.py 10.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import copy
6
import os
7
import warnings
8
from functools import lru_cache
9
from pathlib import Path
10
from typing import TYPE_CHECKING, Any, Optional, Union
11

12
import huggingface_hub
13
from transformers import (AutoTokenizer, PreTrainedTokenizer,
14
15
                          PreTrainedTokenizerFast)

16
from vllm import envs
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.logger import init_logger
18
19
from vllm.transformers_utils.config import (
    get_sentence_transformer_tokenizer_config)
20
from vllm.transformers_utils.tokenizers import MistralTokenizer
21
from vllm.transformers_utils.utils import check_gguf_file
22
from vllm.utils import make_async
23

24
25
if TYPE_CHECKING:
    from vllm.config import ModelConfig
26
27
28
29
30
31
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
    ModelConfig = Any
    LoRARequest = Any
    TokenizerBase = Any
32

33
34
# Module-level logger used by the helpers in this file.
logger = init_logger(__name__)

# Every tokenizer flavour the helpers below accept or return: slow and fast
# HuggingFace tokenizers plus vLLM's TokenizerBase interface. Outside of
# TYPE_CHECKING, TokenizerBase is bound to Any, so this alias is only
# precise for static type checkers.
AnyTokenizer = Union[
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TokenizerBase,
]
37

38

39
40
41
42
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Turn *token_ids* back into text, independent of the tokenizer backend.

    Mirrors HF's `tokenizer.decode(token_ids, ...)`, but dispatches to the
    backend's `_decode` hook when it has one and to `decode` otherwise.

    Leaving `skip_special_tokens` as `None` omits the argument entirely, so
    the backend's own default applies.
    """
    # `decode` is looked up unconditionally (it is the getattr fallback), so
    # the tokenizer must expose it even when `_decode` ends up being used.
    public_decode = tokenizer.decode
    decoder = getattr(tokenizer, "_decode", public_decode)

    extra: dict = {}
    if skip_special_tokens is not None:
        extra["skip_special_tokens"] = skip_special_tokens
    return decoder(token_ids, **extra)
58
59


60
61
62
63
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    truncation: Optional[bool] = None,
    max_length: Optional[int] = None,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Turn *text* into token ids, independent of the tokenizer backend.

    Mirrors HF's `tokenizer.encode(text, ...)`. Any option left as `None`
    is not forwarded at all, so the backend's own default applies (e.g.
    `add_special_tokens=None` keeps the backend's default behaviour).
    """
    # Ordered (name, value) pairs; only explicitly-set values are forwarded.
    candidates = (
        ("max_length", max_length),
        ("truncation", truncation),
        ("add_special_tokens", add_special_tokens),
    )
    forwarded = {
        name: value
        for name, value in candidates if value is not None
    }
    return tokenizer.encode(text, **forwarded)
87
88


89
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """
    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown.
    This proxy caches these properties for faster access.

    Returns a shallow copy of *tokenizer* whose class is swapped for a
    generated subclass named ``Cached<OriginalClassName>``. That subclass
    serves ``all_special_ids``, ``all_special_tokens``,
    ``all_special_tokens_extended``, ``get_vocab()`` and ``len()`` from a
    snapshot taken when this function runs. It also adds a ``max_token_id``
    property. The snapshot is never refreshed.

    Pickling the result re-runs this function on the original tokenizer.
    """
    cached_tokenizer = copy.copy(tokenizer)

    # Snapshot of the values that would otherwise be recomputed per access.
    tokenizer_all_special_ids = tokenizer.all_special_ids
    tokenizer_all_special_tokens = tokenizer.all_special_tokens
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_vocab = tokenizer.get_vocab()
    tokenizer_len = len(tokenizer)

    max_token_id = max(tokenizer_vocab.values())
    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    # `vocab_size` is read inside the suppress block instead of being probed
    # with hasattr(). hasattr() evaluates the property and only swallows
    # AttributeError, so a NotImplementedError (e.g. from an unimplemented
    # base-class `vocab_size`) would otherwise escape.
    # NOTE(review): vocab_size is normally a count, so this can make
    # max_token_id one larger than the largest real id. It is kept as-is
    # for compatibility with existing callers.
    with contextlib.suppress(AttributeError, NotImplementedError):
        max_token_id = max(max_token_id, tokenizer.vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore

        @property
        def all_special_ids(self) -> list[int]:
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self) -> list[str]:
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self) -> list[str]:
            return tokenizer_all_special_tokens_extended

        @property
        def max_token_id(self) -> int:
            return max_token_id

        def get_vocab(self) -> dict[str, int]:
            return tokenizer_vocab

        def __len__(self) -> int:
            return tokenizer_len

        def __reduce__(self):
            # Rebuild from the original (uncached) tokenizer when unpickled.
            return get_cached_tokenizer, (tokenizer, )

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    cached_tokenizer.__class__ = CachedTokenizer
    return cached_tokenizer
144
145


146
def _lowercase_special_token(value: Any) -> Any:
    """Lower-case one value of a tokenizer's ``special_tokens_map``.

    Most values are plain strings, but ``additional_special_tokens`` maps to
    a list of tokens, so lists/tuples are lower-cased element-wise. Any other
    value (e.g. a token object without ``.lower()``) is returned unchanged.
    """
    if isinstance(value, str):
        return value.lower()
    if isinstance(value, (list, tuple)):
        return [_lowercase_special_token(item) for item in value]
    return value


def get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Hub repo id or local path. A path to a GGUF file is
            split into its parent directory plus a ``gguf_file`` kwarg.
        *args: Forwarded to ``TokenizerRegistry.get_tokenizer`` ("custom"
            mode) or ``AutoTokenizer.from_pretrained`` (default path).
        tokenizer_mode: "mistral" loads a MistralTokenizer, "custom" uses the
            TokenizerRegistry, and "slow" forces ``use_fast=False``. Any other
            value (e.g. "auto") goes through ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        revision: Revision forwarded to whichever loader is used.
        download_dir: Cache directory for ModelScope downloads. It is also
            forwarded to the custom tokenizer registry.
        **kwargs: Extra loader kwargs. ``truncation_side`` defaults to "left".

    Returns:
        The loaded tokenizer. Tokenizers loaded through ``AutoTokenizer``
        are wrapped with ``get_cached_tokenizer``.

    Raises:
        ValueError: ``use_fast=True`` was combined with "slow" mode, or
            ``AutoTokenizer`` failed for another reason.
        RuntimeError: ``AutoTokenizer`` needs remote code but
            ``trust_remote_code`` is False.
    """
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # avoid circular import
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
                tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = "left"

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            'It is strongly recommended to run mistral models with '
            '`--tokenizer-mode "mistral"` to ensure correct '
            'encoding and decoding.',
            FutureWarning,
            stacklevel=2)

    tokenizer: AnyTokenizer
    if tokenizer_mode == "mistral":
        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                     revision=revision)
    elif tokenizer_mode == "custom":
        from vllm.transformers_utils.tokenizer_base import TokenizerRegistry
        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
                                                    *args,
                                                    revision=revision,
                                                    download_dir=download_dir,
                                                    **kwargs)
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                    "does not exist or is not currently imported." in str(e)
                    or "requires you to execute the tokenizer file" in str(e)):
                err_msg = ("Failed to load the tokenizer. If the tokenizer "
                           "is a custom tokenizer not yet available in the "
                           "HuggingFace transformers library, consider "
                           "setting `trust_remote_code=True` in LLM or using "
                           "the `--trust-remote-code` flag in the CLI.")
                raise RuntimeError(err_msg) from e
            raise

        # The special_tokens in tokenizer should also be
        # controlled by do_lower_case in encoder_config.
        # Values may be lists (e.g. additional_special_tokens), which a bare
        # `.lower()` call would crash on, so lower-case them via the helper.
        encoder_config = get_sentence_transformer_tokenizer_config(
            tokenizer_name, revision)
        if isinstance(encoder_config, dict) and encoder_config.get(
                "do_lower_case", False):
            special_tokens_map = {
                k: _lowercase_special_token(v)
                for k, v in tokenizer.special_tokens_map.items()
            }
            tokenizer.add_special_tokens(special_tokens_map)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead.")
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer
260
261


262
263
264
265
# Memoized get_tokenizer. Repeated calls with equal (hashable) arguments
# return the previously built tokenizer object. The bound of 128 entries is
# the same default that `lru_cache(get_tokenizer)` would apply.
cached_get_tokenizer = lru_cache(maxsize=128)(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: ModelConfig,
    **kwargs: Any,
):
    """Return the memoized tokenizer described by *model_config*.

    Reads ``tokenizer``, ``tokenizer_mode``, ``tokenizer_revision`` and
    ``trust_remote_code`` from the config. Any extra *kwargs* are forwarded
    to ``cached_get_tokenizer``, so identical requests can reuse a cached
    tokenizer instance.
    """
    tokenizer_id = model_config.tokenizer
    mode = model_config.tokenizer_mode
    tokenizer_revision = model_config.tokenizer_revision
    allow_remote_code = model_config.trust_remote_code
    return cached_get_tokenizer(
        tokenizer_id,
        tokenizer_mode=mode,
        revision=tokenizer_revision,
        trust_remote_code=allow_remote_code,
        **kwargs,
    )


278
def get_lora_tokenizer(lora_request: Optional[LoRARequest], *args,
                       **kwargs) -> Optional[AnyTokenizer]:
    """Load the tokenizer stored alongside a LoRA adapter, if there is one.

    Args:
        lora_request: Request whose ``lora_path`` is passed to
            ``get_tokenizer``. ``None`` is accepted and yields ``None``.
        *args: Forwarded to ``get_tokenizer``.
        **kwargs: Forwarded to ``get_tokenizer``.

    Returns:
        The adapter's tokenizer. Returns ``None`` when there is no request or
        when loading from ``lora_path`` raised. In that case the base model
        tokenizer is meant to be used instead.
    """
    if lora_request is None:
        return None
    try:
        return get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except Exception as e:
        # Deliberately broad: any failure to load means "no usable tokenizer
        # in the LoRA folder", which is non-fatal. Use base model tokenizer.
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_path, e)
        return None


# get_lora_tokenizer wrapped by vllm.utils.make_async.
# NOTE(review): presumably an awaitable variant that runs the blocking
# tokenizer load off the event loop — confirm against make_async.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)