tokenizer.py 6.64 KB
Newer Older
1
import os
2
import warnings
3
from pathlib import Path
4
from typing import Optional, Union
5

6
import huggingface_hub
7
from transformers import (AutoTokenizer, PreTrainedTokenizer,
8
9
                          PreTrainedTokenizerFast)

10
from vllm.envs import VLLM_USE_MODELSCOPE
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.lora.request import LoRARequest
13
14
from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
                                                MistralTokenizer)
15
from vllm.transformers_utils.utils import check_gguf_file
16
from vllm.utils import make_async
17
18
19

logger = init_logger(__name__)

# Every tokenizer type this module can return: a HuggingFace slow or fast
# tokenizer, or vLLM's MistralTokenizer.
AnyTokenizer = Union[
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    MistralTokenizer,
]
22

23

24
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """Patch ``tokenizer`` in place so that frequently used properties are
    cached.

    By default, transformers recomputes several tokenizer properties on every
    access, which causes a significant slowdown. This takes a one-time
    snapshot of ``all_special_ids``, ``all_special_tokens``,
    ``all_special_tokens_extended`` and ``len(tokenizer)``. It then swaps the
    instance's class for a subclass whose accessors return those snapshots.

    ``all_special_ids`` and ``all_special_tokens`` are stored as sets, not
    lists, so membership tests are fast.

    Args:
        tokenizer: The tokenizer instance to patch. It is mutated in place.

    Returns:
        The same ``tokenizer`` object, now an instance of the cached subclass.
    """
    base_cls = tokenizer.__class__

    # Take every snapshot exactly once, before the class swap.
    special_ids = set(tokenizer.all_special_ids)
    special_tokens_extended = tokenizer.all_special_tokens_extended
    special_tokens = set(tokenizer.all_special_tokens)
    cached_len = len(tokenizer)

    class CachedTokenizer(base_cls):  # type: ignore
        # Each accessor returns the snapshot captured above instead of
        # recomputing the value on the base class.
        all_special_ids = property(lambda self: special_ids)
        all_special_tokens = property(lambda self: special_tokens)
        all_special_tokens_extended = property(
            lambda self: special_tokens_extended)

        def __len__(self):
            return cached_len

    CachedTokenizer.__name__ = "Cached" + base_cls.__name__

    # Re-class the existing object so every holder of it sees cached values.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer


62
def get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Model id or local path. A path that
            ``check_gguf_file`` accepts is split into its parent directory
            (used as the model location) plus a ``gguf_file`` keyword
            argument holding the file name.
        *args: Extra positional arguments forwarded to ``from_pretrained``
            in the HuggingFace path; ignored in ``"mistral"`` mode.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; ``"mistral"``
            loads a ``MistralTokenizer``; any other value goes through
            ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``from_pretrained`` in the
            HuggingFace path.
        revision: Model revision, forwarded to every loader used here.
        download_dir: ModelScope cache directory; only used when
            ``VLLM_USE_MODELSCOPE`` is set.
        **kwargs: Extra keyword arguments for ``from_pretrained``.
            ``truncation_side`` defaults to ``"left"`` when absent.

    Returns:
        A ``MistralTokenizer`` in ``"mistral"`` mode. Otherwise, a
        HuggingFace tokenizer patched by ``get_cached_tokenizer``.

    Raises:
        ValueError: If ``tokenizer_mode`` is ``"slow"`` while ``use_fast``
            is truthy, or when re-raised from ``from_pretrained``.
        RuntimeError: If loading fails with an error suggesting a custom
            tokenizer while ``trust_remote_code`` is False.
    """
    if VLLM_USE_MODELSCOPE:
        # Lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # Fetch only the tokenizer files; weight files are excluded and the
        # model itself is downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            tokenizer_name = snapshot_download(
                model_id=tokenizer_name,
                cache_dir=download_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    kwargs.setdefault("truncation_side", "left")

    # For GGUF checkpoints, pass the containing folder as the model location
    # and the file name separately.
    if check_gguf_file(tokenizer_name):
        gguf_path = Path(tokenizer_name)
        kwargs["gguf_file"] = gguf_path.name
        tokenizer_name = gguf_path.parent

    use_mistral = tokenizer_mode == "mistral"
    org_name = str(tokenizer_name).partition("/")[0]
    if org_name == "mistralai" and not use_mistral:
        warnings.warn(
            'It is strongly recommended to run mistral models with '
            '`--tokenizer_mode "mistral"` to ensure correct '
            'encoding and decoding.',
            FutureWarning,
            stacklevel=2)

    if use_mistral:
        return MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                revision=revision)

    try:
        hf_tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
    except ValueError as e:
        # These messages mean the tokenizer class is missing or not imported,
        # so point the user at the --trust-remote-code flag.
        error_text = str(e)
        looks_like_custom_tokenizer = (
            "does not exist or is not currently imported." in error_text
            or "requires you to execute the tokenizer file" in error_text)
        if trust_remote_code or not looks_like_custom_tokenizer:
            raise e
        raise RuntimeError(
            "Failed to load the tokenizer. If the tokenizer "
            "is a custom tokenizer not yet available in the "
            "HuggingFace transformers library, consider "
            "setting `trust_remote_code=True` in LLM or using "
            "the `--trust-remote-code` flag in the CLI.") from e
    except AttributeError as e:
        # Works around "'BaichuanTokenizer' object has no attribute
        # 'sp_model'" by retrying with the BaichuanTokenizer class imported
        # from vllm.transformers_utils.tokenizers.
        if "BaichuanTokenizer" not in str(e):
            raise e
        hf_tokenizer = BaichuanTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )

    if not isinstance(hf_tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead.")
    return get_cached_tokenizer(hf_tokenizer)
162
163


164
def get_lora_tokenizer(lora_request: Optional[LoRARequest], *args,
                       **kwargs) -> Optional[AnyTokenizer]:
    """Load the tokenizer stored in a LoRA adapter's folder, if there is one.

    Args:
        lora_request: The LoRA request whose ``lora_path`` is passed to
            ``get_tokenizer``. ``None`` is accepted and yields ``None``.
        *args: Forwarded unchanged to ``get_tokenizer``.
        **kwargs: Forwarded unchanged to ``get_tokenizer``.

    Returns:
        The tokenizer loaded from ``lora_request.lora_path``. Returns
        ``None`` when ``lora_request`` is ``None`` or when loading raises
        ``OSError``. In the ``OSError`` case a warning is logged saying the
        base model tokenizer will be used instead.
    """
    if lora_request is None:
        return None
    try:
        return get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except OSError as e:
        # No tokenizer was found in the LoRA folder; return None so the
        # caller falls back to the base model tokenizer.
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_path, e)
        return None


# Async variant of get_lora_tokenizer built with vllm.utils.make_async.
# NOTE(review): presumably runs the blocking load off the event loop;
# confirm against make_async.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)