# tokenizer.py (8.56 KB)
1
import contextlib
2
import os
3
import warnings
4
from pathlib import Path
5
from types import MethodType
6
from typing import Optional, Union
7

8
import huggingface_hub
9
from transformers import (AutoTokenizer, PreTrainedTokenizer,
10
11
                          PreTrainedTokenizerFast)

12
from vllm.envs import VLLM_USE_MODELSCOPE
# (repository web-view blame label: "Woosuk Kwon committed")
13
from vllm.logger import init_logger
14
from vllm.lora.request import LoRARequest
15
from vllm.transformers_utils.tokenizers import MistralTokenizer
16
from vllm.transformers_utils.utils import check_gguf_file
17
from vllm.utils import make_async
18
19
20

# Module-level logger, created with the `init_logger` helper imported from
# `vllm.logger` and keyed by this module's name.
logger = init_logger(__name__)

21
22
# Type alias covering every tokenizer type the helpers in this module
# accept or return: the two HuggingFace tokenizer base classes and vLLM's
# `MistralTokenizer`.
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
                     MistralTokenizer]
23

24

25
26
27
28
29
30
31
32
33
34
35
36
37
def decode_tokens(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Convert token ids back into text, independent of the tokenizer backend.

    This mirrors HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)` by
    forwarding both arguments to :code:`tokenizer.decode`.

    Args:
        tokenizer: Tokenizer whose :code:`decode` method is called.
        token_ids: Token ids to turn into text.
        skip_special_tokens: Passed through unchanged to
            :code:`tokenizer.decode` (defaults to ``False``).

    Returns:
        The value returned by :code:`tokenizer.decode`.
    """
    decode_options = {"skip_special_tokens": skip_special_tokens}
    return tokenizer.decode(token_ids, **decode_options)


38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Convert text into token ids, independent of the tokenizer backend.

    This mirrors HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.

    Args:
        tokenizer: Tokenizer used for encoding.
        text: The text to encode.
        add_special_tokens: When ``None`` (the default) and the tokenizer is
            not a :code:`MistralTokenizer`, the argument is omitted so the
            tokenizer's own default applies. For a :code:`MistralTokenizer`
            the value is always forwarded, as both ``bos`` and ``eos``.

    Returns:
        The token ids produced by the underlying :code:`encode` call.
    """
    if isinstance(tokenizer, MistralTokenizer):
        # The Mistral wrapper keeps its backend in `.tokenizer`; that
        # backend takes separate `bos`/`eos` flags rather than
        # `add_special_tokens`, so the single flag drives both.
        backend = tokenizer.tokenizer
        return backend.encode(text,
                              bos=add_special_tokens,
                              eos=add_special_tokens)

    if add_special_tokens is None:
        # No explicit preference: let the tokenizer pick its default.
        return tokenizer.encode(text)

    return tokenizer.encode(text, add_special_tokens=add_special_tokens)


57
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """Get tokenizer with cached properties.

    This will patch the tokenizer object in place.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown. This
    function caches these properties for faster access.

    The cached values are read once, when this function is called, and the
    patched properties return those same values on every later access.

    Args:
        tokenizer: Tokenizer to patch. The function reads its
            ``all_special_ids``, ``all_special_tokens``,
            ``all_special_tokens_extended``, ``get_vocab()`` and ``len()``,
            and ``vocab_size`` when that attribute is available.

    Returns:
        The same ``tokenizer`` object, whose ``__class__`` has been replaced
        by a dynamically created subclass named
        ``Cached<OriginalClassName>``.
    """
    # Snapshot everything up front; the subclass below only returns these.
    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
    tokenizer_len = len(tokenizer)

    max_token_id = max(tokenizer.get_vocab().values())
    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    #
    # The attribute is read directly instead of being guarded by
    # `hasattr`: `hasattr` only swallows AttributeError, so a `vocab_size`
    # property that raises NotImplementedError would propagate out of the
    # guard itself. Suppressing both exceptions covers "attribute missing"
    # and "attribute not implemented" alike.
    with contextlib.suppress(AttributeError, NotImplementedError):
        max_token_id = max(max_token_id, tokenizer.vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore

        @property
        def all_special_ids(self):
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self):
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return tokenizer_all_special_tokens_extended

        @property
        def max_token_id(self):
            return max_token_id

        def __len__(self):
            return tokenizer_len

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    # Swap the class on the existing instance so all other state is kept.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer


108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
    """Make ``tokenizer._pad`` tolerate a ``padding_side`` keyword.

    The tokenizer's current ``_pad`` is wrapped by a replacement bound to
    this instance. The replacement strips ``padding_side`` from the call,
    warns when the requested side differs from ``tokenizer.padding_side``
    (because the value is dropped), and then delegates to the original
    ``_pad`` with the remaining arguments.

    Args:
        tokenizer: Tokenizer instance to patch in place.
    """
    original_pad = tokenizer._pad

    def _pad(
        self: PreTrainedTokenizer,
        *args,
        padding_side: Optional[str] = None,
        **kwargs,
    ):
        # Only a request that conflicts with the tokenizer's configured
        # side is worth a warning; `None` or a matching side is silent.
        side_is_overridden = (padding_side is not None
                              and padding_side != self.padding_side)
        if side_is_overridden:
            warnings.warn(
                "`padding_side` argument is not supported by "
                f"{type(tokenizer).__name__} and will be ignored.",
                stacklevel=2)

        # `padding_side` is deliberately not forwarded.
        return original_pad(*args, **kwargs)

    tokenizer._pad = MethodType(_pad, tokenizer)


128
def get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Model identifier or local path. A GGUF file path is
            split into its parent folder and file name before loading.
        *args: Extra positional arguments forwarded to
            ``AutoTokenizer.from_pretrained``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; ``"mistral"``
            loads through ``MistralTokenizer.from_pretrained``; any other
            value goes through ``AutoTokenizer.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        revision: Revision forwarded to whichever loader is used.
        download_dir: Cache directory for the ModelScope snapshot download;
            not used otherwise in this function.
        **kwargs: Extra keyword arguments forwarded to
            ``AutoTokenizer.from_pretrained``. ``truncation_side`` defaults
            to ``"left"`` when not supplied.

    Returns:
        The loaded tokenizer. Tokenizers loaded through ``AutoTokenizer``
        are additionally passed through ``get_cached_tokenizer``.

    Raises:
        ValueError: If ``tokenizer_mode`` is ``"slow"`` while ``use_fast``
            is truthy, or re-raised from ``AutoTokenizer.from_pretrained``
            when the error is not a remote-code issue.
        RuntimeError: If loading failed in a way that suggests
            ``trust_remote_code=True`` is required.
    """
    if VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            tokenizer_path = snapshot_download(
                model_id=tokenizer_name,
                cache_dir=download_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                # Ignore weights - we only need the tokenizer.
                ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
            tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # Respect an explicit caller choice; otherwise truncate from the left.
    kwargs.setdefault("truncation_side", "left")

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            'It is strongly recommended to run mistral models with '
            '`--tokenizer-mode "mistral"` to ensure correct '
            'encoding and decoding.',
            FutureWarning,
            stacklevel=2)

    if tokenizer_mode == "mistral":
        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                     revision=revision)
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            error_text = str(e)
            needs_remote_code = (
                "does not exist or is not currently imported." in error_text
                or "requires you to execute the tokenizer file" in error_text)
            if not trust_remote_code and needs_remote_code:
                err_msg = ("Failed to load the tokenizer. If the tokenizer "
                           "is a custom tokenizer not yet available in the "
                           "HuggingFace transformers library, consider "
                           "setting `trust_remote_code=True` in LLM or using "
                           "the `--trust-remote-code` flag in the CLI.")
                raise RuntimeError(err_msg) from e
            # Anything else is re-raised with its original traceback.
            raise

        # NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
        if type(tokenizer).__name__ in ("ChatGLMTokenizer",
                                        "ChatGLM4Tokenizer"):
            assert isinstance(tokenizer, PreTrainedTokenizer)
            patch_padding_side(tokenizer)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead.")
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer
221
222


223
def get_lora_tokenizer(lora_request: Optional[LoRARequest], *args,
                       **kwargs) -> Optional[AnyTokenizer]:
    """Load the tokenizer stored alongside a LoRA adapter, if there is one.

    Args:
        lora_request: The LoRA request whose ``lora_path`` is searched for
            a tokenizer, or ``None``.
        *args: Extra positional arguments forwarded to ``get_tokenizer``.
        **kwargs: Extra keyword arguments forwarded to ``get_tokenizer``.

    Returns:
        The tokenizer loaded from ``lora_request.lora_path``. ``None`` is
        returned when ``lora_request`` is ``None`` or when loading raised
        any ``Exception``; in the latter case a warning is logged and the
        caller is expected to use the base model tokenizer.
    """
    if lora_request is None:
        return None

    try:
        return get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except Exception as e:
        # Deliberately broad: any failure to load is treated as "the LoRA
        # folder has no usable tokenizer" rather than as a fatal error.
        # No tokenizer was found in the LoRA folder,
        # use base model tokenizer
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_path, e)
        return None


# `get_lora_tokenizer` passed through `make_async` from `vllm.utils`.
# NOTE(review): presumably an awaitable variant of the same function;
# confirm against the `make_async` implementation.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)