registry.py 9.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
from dataclasses import dataclass, field
5
from functools import lru_cache
6
from pathlib import Path
7
from typing import TYPE_CHECKING
8
9

import huggingface_hub
10
from typing_extensions import TypeVar, assert_never
11
12
13

import vllm.envs as envs
from vllm.logger import init_logger
14
from vllm.transformers_utils.config import get_config
15
from vllm.transformers_utils.gguf_utils import (
16
    check_gguf_file,
17
    get_gguf_file_path_from_hf,
18
19
20
21
    is_gguf,
    is_remote_gguf,
    split_remote_gguf,
)
22
23
24
25
from vllm.transformers_utils.repo_utils import (
    any_pattern_in_repo_files,
    is_mistral_model_repo,
)
26
from vllm.utils.import_utils import resolve_obj_by_qualname
27
28
29

from .protocol import TokenizerLike

30
if TYPE_CHECKING:
31
    from vllm.config.model import ModelConfig, RunnerType
32

33
34
35
logger = init_logger(__name__)


36
37
38
39
40
41
42
# Model types whose ``tokenizer_class`` on the hub is wrong. For these, the generic
# fast tokenizer (TokenizersBackend) is used instead. Every entry is a temporary
# workaround; better long-term fixes are:
# - Add the model type to MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS in
#   transformers (better)
# - Fix tokenizer_class on the hub for the affected models (best)
_MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS: set[str] = {
    "step3_vl",
}

# Built-in tokenizer modes:
#   tokenizer_mode -> (module name under ``vllm.tokenizers``, tokenizer class name)
_VLLM_TOKENIZERS = {
    "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
    "grok2": ("grok2", "Grok2Tokenizer"),
    "hf": ("hf", "CachedHfTokenizer"),
    "kimi_audio": ("kimi_audio", "KimiAudioTokenizer"),
    "mistral": ("mistral", "MistralTokenizer"),
    "qwen_vl": ("qwen_vl", "QwenVLTokenizer"),
}
51
52


53
54
55
56
@dataclass
class _TokenizerRegistry:
    """Registry mapping a ``tokenizer_mode`` to a lazily imported tokenizer class.

    Each entry is stored as a ``(module, class_name)`` pair; the class is only
    imported (through ``resolve_obj_by_qualname``) when it is requested.
    """

    # Tokenizer mode -> (tokenizer module, tokenizer class)
    tokenizers: dict[str, tuple[str, str]] = field(default_factory=dict)

    def register(self, tokenizer_mode: str, module: str, class_name: str) -> None:
        """Register ``module.class_name`` as the tokenizer for ``tokenizer_mode``.

        If the mode already has a tokenizer, it is replaced and a warning naming
        both the previous and the new class is logged.
        """
        previous = self.tokenizers.get(tokenizer_mode)
        if previous is not None:
            # Report the entry being replaced (not the incoming one) as the
            # "already registered" class, and name its replacement.
            prev_module, prev_class_name = previous
            logger.warning(
                "%s.%s is already registered for tokenizer_mode=%r. "
                "It is overwritten by %s.%s.",
                prev_module,
                prev_class_name,
                tokenizer_mode,
                module,
                class_name,
            )

        self.tokenizers[tokenizer_mode] = (module, class_name)

    def load_tokenizer_cls(self, tokenizer_mode: str) -> type[TokenizerLike]:
        """Import and return the tokenizer class registered for ``tokenizer_mode``.

        Raises:
            ValueError: if no tokenizer is registered for ``tokenizer_mode``.
        """
        try:
            module, class_name = self.tokenizers[tokenizer_mode]
        except KeyError:
            raise ValueError(
                f"No tokenizer registered for {tokenizer_mode=!r}."
            ) from None

        logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}")

        return resolve_obj_by_qualname(f"{module}.{class_name}")

    def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike:
        """Instantiate the tokenizer for ``tokenizer_mode`` via ``from_pretrained``.

        ``args`` and ``kwargs`` are forwarded unchanged to ``from_pretrained``.

        Raises:
            ValueError: if no tokenizer is registered for ``tokenizer_mode``.
        """
        tokenizer_cls = self.load_tokenizer_cls(tokenizer_mode)
        return tokenizer_cls.from_pretrained(*args, **kwargs)


# Global registry pre-populated with vLLM's built-in tokenizers, whose modules
# live under the ``vllm.tokenizers`` package.
TokenizerRegistry = _TokenizerRegistry(
    {
        mode: (f"vllm.tokenizers.{mod_relname}", cls_name)
        for mode, (mod_relname, cls_name) in _VLLM_TOKENIZERS.items()
    }
)


def resolve_tokenizer_args(
    tokenizer_name: str | Path,
    *args,
    runner_type: "RunnerType" = "generate",
    tokenizer_mode: str = "auto",
    **kwargs,
):
    """Normalize the arguments used to construct a tokenizer.

    The following adjustments are applied in order:

    1. ModelScope: when ``VLLM_USE_MODELSCOPE`` is set and ``tokenizer_name`` is
       not an existing local path, the repo is snapshot-downloaded (skipping
       ``.pt``/``.safetensors``/``.bin`` files) under a file lock, and
       ``tokenizer_name`` becomes the path returned by the download.
    2. GGUF: if ``check_gguf_file`` accepts ``tokenizer_name``, its file name is
       put in ``kwargs["gguf_file"]`` and its parent directory becomes
       ``tokenizer_name``; for a remote GGUF reference, ``tokenizer_name``
       becomes the repo id and ``kwargs["gguf_file"]`` the hub path of the file
       for the requested quant type.
    3. ``truncation_side`` defaults to ``"left"`` for the generate/draft runners
       and ``"right"`` for the pooling runner, unless the caller already set it.
    4. ``tokenizer_mode``: ``"slow"`` becomes ``"hf"`` with ``use_fast=False``;
       ``"auto"`` becomes ``"mistral"`` for Mistral repos that contain
       ``tekken.json`` or ``tokenizer.model.v*`` files, otherwise ``"hf"``.

    ``revision`` and ``download_dir`` are read from ``kwargs`` but left in it.

    Returns:
        A ``(tokenizer_mode, tokenizer_name, args, kwargs)`` tuple.

    Raises:
        ValueError: if ``tokenizer_mode == "slow"`` and ``use_fast`` is truthy.
    """
    revision: str | None = kwargs.get("revision")
    download_dir: str | None = kwargs.get("download_dir")

    if envs.VLLM_USE_MODELSCOPE:
        # Imported lazily: modelscope is only required when this path is
        # enabled, and importing weight_utils at module level is circular.
        from modelscope.hub.snapshot_download import snapshot_download

        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only the tokenizer is fetched here; the model itself is downloaded
        # on the workers.
        if not Path(tokenizer_name).exists():
            # The file lock stops several processes from downloading the same
            # repo at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_name = snapshot_download(
                    model_id=str(tokenizer_name),
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Weight files are not needed to build a tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
                )

    # GGUF checkpoints: load the tokenizer from the folder / repo holding the
    # GGUF file and point at the file itself through ``gguf_file``.
    if is_gguf(tokenizer_name):
        if check_gguf_file(tokenizer_name):
            gguf_path = Path(tokenizer_name)
            kwargs["gguf_file"] = gguf_path.name
            tokenizer_name = gguf_path.parent
        elif is_remote_gguf(tokenizer_name):
            repo_id, quant_type = split_remote_gguf(tokenizer_name)
            tokenizer_name = repo_id
            kwargs["gguf_file"] = get_gguf_file_path_from_hf(
                repo_id,
                quant_type,
                revision=revision,
            )

    if "truncation_side" not in kwargs:
        if runner_type == "generate" or runner_type == "draft":
            kwargs["truncation_side"] = "left"
        elif runner_type == "pooling":
            kwargs["truncation_side"] = "right"
        else:
            assert_never(runner_type)

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        # "slow" is the HF tokenizer with the fast implementation disabled.
        tokenizer_mode = "hf"
        kwargs["use_fast"] = False

    if tokenizer_mode == "auto":
        # Prefer the official Mistral tokenizer when the repo ships its files;
        # otherwise fall back to the HF tokenizer.
        repo = str(tokenizer_name)
        ships_mistral_tokenizer = is_mistral_model_repo(
            model_name_or_path=repo, revision=revision
        ) and any_pattern_in_repo_files(
            model_name_or_path=repo,
            allow_patterns=["tekken.json", "tokenizer.model.v*"],
            revision=revision,
        )
        tokenizer_mode = "mistral" if ships_mistral_tokenizer else "hf"

    return tokenizer_mode, tokenizer_name, args, kwargs


# Memoized variant (``functools.lru_cache``): every argument must be hashable,
# and cache hits hand back the very same result tuple, including the same
# ``kwargs`` dict, so callers must not mutate it.
cached_resolve_tokenizer_args = lru_cache(resolve_tokenizer_args)


def tokenizer_args_from_config(config: "ModelConfig", **kwargs):
    """Resolve tokenizer-loading arguments from a ``ModelConfig``.

    Reads ``tokenizer``, ``runner_type``, ``tokenizer_mode``,
    ``tokenizer_revision`` and ``trust_remote_code`` from ``config`` and passes
    them, plus any extra ``kwargs``, to the memoized ``resolve_tokenizer_args``.

    Returns:
        ``(tokenizer_mode, tokenizer_name, args, kwargs)``. The ``kwargs`` dict
        is a fresh copy, so callers may modify it without affecting cached
        results.
    """
    tokenizer_mode, tokenizer_name, args, cached_kwargs = (
        cached_resolve_tokenizer_args(
            config.tokenizer,
            runner_type=config.runner_type,
            tokenizer_mode=config.tokenizer_mode,
            revision=config.tokenizer_revision,
            trust_remote_code=config.trust_remote_code,
            **kwargs,
        )
    )
    # lru_cache returns the same result object on every hit; hand out a copy
    # of the kwargs dict so a caller mutating it cannot corrupt later lookups.
    return tokenizer_mode, tokenizer_name, args, dict(cached_kwargs)


_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)


def get_tokenizer(
    tokenizer_name: str | Path,
    *args,
    tokenizer_cls: type[_T] = TokenizerLike,  # type: ignore[assignment]
    trust_remote_code: bool = False,
    revision: str | None = None,
    download_dir: str | None = None,
    **kwargs,
) -> _T:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Path or repo id of the tokenizer (GGUF references are
            handled by ``resolve_tokenizer_args``).
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        tokenizer_cls: Class to instantiate. If left at the default
            ``TokenizerLike``, the class is chosen from ``TokenizerRegistry``
            by the resolved ``tokenizer_mode`` (with a ``TokenizersBackend``
            override for model types listed in
            ``_MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS`` on the HF path).
            An explicitly passed class is always used as-is.
        trust_remote_code: Passed to ``get_config`` and, through the resolved
            kwargs, to ``from_pretrained``.
        revision: Passed to ``get_config`` and, through the resolved kwargs, to
            ``from_pretrained``.
        download_dir: Used for ModelScope downloads and also left in the
            kwargs forwarded to ``from_pretrained``.
        **kwargs: May include ``runner_type`` and ``tokenizer_mode``; all are
            normalized by ``resolve_tokenizer_args`` and the remainder is
            forwarded to ``from_pretrained``.

    Returns:
        The loaded tokenizer. A warning is logged if its ``is_fast`` is false.
    """
    tokenizer_mode, tokenizer_name, args, kwargs = cached_resolve_tokenizer_args(
        tokenizer_name,
        *args,
        trust_remote_code=trust_remote_code,
        revision=revision,
        download_dir=download_dir,
        **kwargs,
    )

    # Ensure that, if the config were to come from vllm.transformers_utils.config, it is
    # registered with AutoConfig before the tokenizer is loaded. This is necessary since
    # tokenizer_cls_.from_pretrained will call AutoConfig.from_pretrained internally.
    # This may fail for paths that don't have a model config (e.g. LoRA adapters),
    # which is fine — those don't need custom config registration.
    config = None
    with contextlib.suppress(ValueError, OSError):
        config = get_config(
            tokenizer_name,
            trust_remote_code=trust_remote_code,
            revision=revision,
        )

    model_type = getattr(config, "model_type", None) if config else None

    if tokenizer_cls is not TokenizerLike:
        # An explicitly requested class always wins: the ``-> _T`` return type
        # promises an instance of it.
        tokenizer_cls_ = tokenizer_cls
    elif (
        tokenizer_mode == "hf"
        and model_type in _MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS
    ):
        # Some models have an incorrect tokenizer_class on the hub. On the HF
        # path, bypass AutoTokenizer and use TokenizersBackend directly; other
        # (explicitly chosen) tokenizer modes are left alone.
        from transformers.tokenization_utils_tokenizers import TokenizersBackend

        logger.debug(
            "Overriding tokenizer_class to TokenizersBackend for model_type=%r",
            model_type,
        )
        tokenizer_cls_ = TokenizersBackend
    else:
        tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode)

    tokenizer = tokenizer_cls_.from_pretrained(tokenizer_name, *args, **kwargs)

    if not tokenizer.is_fast:
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )

    return tokenizer  # type: ignore


# Memoized variant keyed on all call arguments (which must therefore be hashable).
cached_get_tokenizer = lru_cache(get_tokenizer)


256
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
    """Return the memoized tokenizer described by ``model_config``.

    Returns ``None`` without loading anything when
    ``model_config.skip_tokenizer_init`` is truthy. Otherwise delegates to
    ``cached_get_tokenizer`` with the config's tokenizer name, runner type,
    tokenizer mode, tokenizer revision and ``trust_remote_code`` flag, plus any
    extra ``kwargs``.
    """
    if not model_config.skip_tokenizer_init:
        return cached_get_tokenizer(
            model_config.tokenizer,
            runner_type=model_config.runner_type,
            tokenizer_mode=model_config.tokenizer_mode,
            revision=model_config.tokenizer_revision,
            trust_remote_code=model_config.trust_remote_code,
            **kwargs,
        )

    return None