tokenizer.py 7.6 KB
Newer Older
1
import os
2
import warnings
3
from pathlib import Path
4
from types import MethodType
5
from typing import Optional, Union
6

7
import huggingface_hub
8
from transformers import (AutoTokenizer, PreTrainedTokenizer,
9
10
                          PreTrainedTokenizerFast)

11
from vllm.envs import VLLM_USE_MODELSCOPE
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.logger import init_logger
13
from vllm.lora.request import LoRARequest
14
15
from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
                                                MistralTokenizer)
16
from vllm.transformers_utils.utils import check_gguf_file
17
from vllm.utils import make_async
18
19
20

# Module-level logger, created through vLLM's `init_logger` helper and used
# for the warnings emitted by the loaders below.
logger = init_logger(__name__)

21
22
# Type alias covering every tokenizer class this module returns: the two
# HuggingFace base classes plus vLLM's own `MistralTokenizer`.
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
                     MistralTokenizer]
23

24

25
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """Get tokenizer with cached properties.

    This will patch the tokenizer object in place.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown. This
    function caches these properties for faster access.

    Args:
        tokenizer: The tokenizer instance to patch.

    Returns:
        The same ``tokenizer`` object. Its class is replaced by a dynamically
        created subclass named ``Cached<OriginalClassName>`` that serves
        ``all_special_ids``, ``all_special_tokens``,
        ``all_special_tokens_extended`` and ``len()`` from values captured
        when this function was called.

    Note:
        The cached values are snapshots; changes made to the tokenizer's
        special tokens or vocabulary size after this call are not reflected
        by the cached properties.
    """
    # Evaluate each (potentially expensive) property exactly once. The ids
    # and tokens are stored as sets, so callers receive a set rather than
    # the original sequence.
    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
    tokenizer_len = len(tokenizer)

    # Subclass the tokenizer's own class so every other attribute and method
    # keeps its original behaviour; only the cached members are overridden.
    class CachedTokenizer(tokenizer.__class__):  # type: ignore

        @property
        def all_special_ids(self):
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self):
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return tokenizer_all_special_tokens_extended

        def __len__(self):
            return tokenizer_len

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    # Swap the class on the existing instance instead of copying it.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer


63
def get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Model identifier or local path. If it refers to a
            GGUF file (as decided by ``check_gguf_file``), the file name is
            passed on as ``gguf_file`` and its parent directory is used as
            the tokenizer location.
        *args: Extra positional arguments forwarded to ``from_pretrained``
            (not used in ``"mistral"`` mode).
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; ``"mistral"``
            loads a ``MistralTokenizer``; any other value loads through
            ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``from_pretrained`` (not used in
            ``"mistral"`` mode).
        revision: Revision forwarded to the download / ``from_pretrained``
            calls.
        download_dir: Cache directory for the ModelScope snapshot download;
            only read when ``VLLM_USE_MODELSCOPE`` is set.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``
            (not used in ``"mistral"`` mode). ``truncation_side`` defaults to
            ``"left"`` when the caller does not provide it.

    Returns:
        The loaded tokenizer. Tokenizers loaded outside ``"mistral"`` mode are
        passed through ``get_cached_tokenizer`` before being returned.

    Raises:
        ValueError: If ``tokenizer_mode`` is ``"slow"`` while ``use_fast`` is
            truthy in ``kwargs``, or re-raised from ``from_pretrained``.
        RuntimeError: If loading fails with one of the known "custom
            tokenizer" errors while ``trust_remote_code`` is false.
    """
    if VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            tokenizer_path = snapshot_download(
                model_id=tokenizer_name,
                cache_dir=download_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                # Ignore weights - we only need the tokenizer.
                ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
            tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # Respect an explicit caller choice; otherwise truncate from the left.
    kwargs.setdefault("truncation_side", "left")

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            'It is strongly recommended to run mistral models with '
            '`--tokenizer_mode "mistral"` to ensure correct '
            'encoding and decoding.',
            FutureWarning,
            stacklevel=2)
    if tokenizer_mode == "mistral":
        # Only the name and revision are forwarded on this path.
        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                     revision=revision)
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                    "does not exist or is not currently imported." in str(e)
                    or "requires you to execute the tokenizer file" in str(e)):
                err_msg = ("Failed to load the tokenizer. If the tokenizer "
                           "is a custom tokenizer not yet available in the "
                           "HuggingFace transformers library, consider "
                           "setting `trust_remote_code=True` in LLM or using "
                           "the `--trust-remote-code` flag in the CLI.")
                raise RuntimeError(err_msg) from e
            # Any other ValueError is propagated untouched.
            raise
        except AttributeError as e:
            if "BaichuanTokenizer" not in str(e):
                raise
            # This is for the error "'BaichuanTokenizer' object has no
            # attribute 'sp_model'".
            tokenizer = BaichuanTokenizer.from_pretrained(
                tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )

        # NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
        if type(tokenizer).__name__ in ("ChatGLMTokenizer",
                                        "ChatGLM4Tokenizer"):
            assert isinstance(tokenizer, PreTrainedTokenizer)
            # Bound method captured before patching; the wrapper delegates
            # to it without the unsupported `padding_side` argument.
            orig_pad = tokenizer._pad

            # Patch _pad method to accept `padding_side`
            def _pad(
                self: PreTrainedTokenizer,
                *args,
                padding_side: Optional[str] = None,
                **kwargs,
            ):
                if (padding_side is not None
                        and padding_side != self.padding_side):
                    msg = ("`padding_side` argument is not supported by "
                           "ChatGLMTokenizer and will be ignored.")
                    warnings.warn(msg, stacklevel=2)

                return orig_pad(*args, **kwargs)

            tokenizer._pad = MethodType(_pad, tokenizer)

        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.warning(
                "Using a slow tokenizer. This might cause a significant "
                "slowdown. Consider using a fast tokenizer instead.")
        tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer
186
187


188
def get_lora_tokenizer(lora_request: Optional[LoRARequest], *args,
                       **kwargs) -> Optional[AnyTokenizer]:
    """Load the tokenizer stored alongside a LoRA adapter, if any.

    Args:
        lora_request: The LoRA request whose ``lora_path`` is searched for a
            tokenizer, or ``None``.
        *args: Extra positional arguments forwarded to ``get_tokenizer``.
        **kwargs: Extra keyword arguments forwarded to ``get_tokenizer``.

    Returns:
        The tokenizer loaded from ``lora_request.lora_path``, or ``None``
        when ``lora_request`` is ``None`` or when loading fails for any
        reason. A ``None`` result signals that the base model tokenizer
        should be used instead.
    """
    if lora_request is None:
        return None
    try:
        return get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except Exception as e:
        # No tokenizer was found in the LoRA folder,
        # use base model tokenizer. The broad catch is a deliberate
        # best-effort: the failure is logged rather than propagated.
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_path, e)
        return None


# Variant of `get_lora_tokenizer` produced by `vllm.utils.make_async`.
# NOTE(review): presumably awaitable and runs the blocking loader off the
# event loop — confirm against `make_async`'s implementation.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)