registry.py 8.04 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from dataclasses import dataclass, field
4
from functools import lru_cache
5
from pathlib import Path
6
from typing import TYPE_CHECKING
7
8

import huggingface_hub
9
from typing_extensions import TypeVar, assert_never
10
11
12

import vllm.envs as envs
from vllm.logger import init_logger
13
from vllm.transformers_utils.gguf_utils import (
14
    check_gguf_file,
15
    get_gguf_file_path_from_hf,
16
17
18
19
    is_gguf,
    is_remote_gguf,
    split_remote_gguf,
)
20
21
22
23
from vllm.transformers_utils.repo_utils import (
    any_pattern_in_repo_files,
    is_mistral_model_repo,
)
24
from vllm.utils.import_utils import resolve_obj_by_qualname
25
26
27

from .protocol import TokenizerLike

28
if TYPE_CHECKING:
29
    from vllm.config.model import ModelConfig, RunnerType
30

31
32
33
# Module-level logger for this file, created through vLLM's ``init_logger``.
logger = init_logger(__name__)


34
_VLLM_TOKENIZERS = {
35
    "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
Bijaya Dangol's avatar
Bijaya Dangol committed
36
    "grok2": ("grok2", "Grok2Tokenizer"),
37
    "hf": ("hf", "CachedHfTokenizer"),
38
    "kimi_audio": ("kimi_audio", "KimiAudioTokenizer"),
39
    "mistral": ("mistral", "MistralTokenizer"),
40
    "qwen_vl": ("qwen_vl", "QwenVLTokenizer"),
41
}
42
43


44
45
46
47
@dataclass
class _TokenizerRegistry:
    # Tokenizer mode ->  (tokenizer module, tokenizer class)
    tokenizers: dict[str, tuple[str, str]] = field(default_factory=dict)
48

49
50
    def register(self, tokenizer_mode: str, module: str, class_name: str) -> None:
        if tokenizer_mode in self.tokenizers:
51
52
53
54
55
56
57
58
            logger.warning(
                "%s.%s is already registered for tokenizer_mode=%r. "
                "It is overwritten by the new one.",
                module,
                class_name,
                tokenizer_mode,
            )

59
        self.tokenizers[tokenizer_mode] = (module, class_name)
60
61
62

        return None

63
64
    def load_tokenizer_cls(self, tokenizer_mode: str) -> type[TokenizerLike]:
        if tokenizer_mode not in self.tokenizers:
65
66
            raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.")

67
        module, class_name = self.tokenizers[tokenizer_mode]
68
        logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}")
69

70
        return resolve_obj_by_qualname(f"{module}.{class_name}")
71

72
73
74
    def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike:
        tokenizer_cls = self.load_tokenizer_cls(tokenizer_mode)
        return tokenizer_cls.from_pretrained(*args, **kwargs)
75

76
77
78
79
80
81
82
83
84
85

# Registry pre-populated with the built-in tokenizers; each relative module
# name is expanded to its absolute ``vllm.tokenizers.<module>`` import path.
TokenizerRegistry = _TokenizerRegistry(
    tokenizers={
        mode: ("vllm.tokenizers." + rel_module, cls_name)
        for mode, (rel_module, cls_name) in _VLLM_TOKENIZERS.items()
    }
)


def _maybe_download_from_modelscope(
    tokenizer_name: str | Path,
    revision: str | None,
    download_dir: str | None,
) -> str | Path:
    """Download tokenizer files from ModelScope if ``VLLM_USE_MODELSCOPE`` is set.

    Returns the local snapshot path when a download was performed; otherwise
    ``tokenizer_name`` unchanged (ModelScope disabled, or the name already
    exists as a local path).
    """
    if not envs.VLLM_USE_MODELSCOPE:
        return tokenizer_name

    # Lazy import so that modelscope is not required for normal use.
    from modelscope.hub.snapshot_download import snapshot_download

    # Imported here to avoid a circular import.
    from vllm.model_executor.model_loader.weight_utils import get_lock

    # Only the tokenizer is fetched here; the model is downloaded on the workers.
    if Path(tokenizer_name).exists():
        return tokenizer_name

    # A file lock prevents multiple processes from downloading the same files
    # at the same time.
    with get_lock(tokenizer_name, download_dir):
        return snapshot_download(
            model_id=str(tokenizer_name),
            cache_dir=download_dir,
            revision=revision,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            # Ignore weights - we only need the tokenizer.
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
        )


def _split_gguf_tokenizer_name(
    tokenizer_name: str | Path,
    revision: str | None,
) -> tuple[str | Path, dict]:
    """Separate the model folder / repo from the GGUF file for GGUF models.

    Returns ``(tokenizer_name, extra_kwargs)``; ``extra_kwargs`` holds the
    ``gguf_file`` entry for local or remote GGUF names and is empty otherwise.
    """
    if not is_gguf(tokenizer_name):
        return tokenizer_name, {}

    if check_gguf_file(tokenizer_name):
        gguf_path = Path(tokenizer_name)
        return gguf_path.parent, {"gguf_file": gguf_path.name}

    if is_remote_gguf(tokenizer_name):
        repo_id, quant_type = split_remote_gguf(tokenizer_name)
        # Get the HuggingFace Hub path for the GGUF file.
        gguf_file = get_gguf_file_path_from_hf(
            repo_id,
            quant_type,
            revision=revision,
        )
        return repo_id, {"gguf_file": gguf_file}

    return tokenizer_name, {}


def _default_truncation_side(runner_type: "RunnerType") -> str:
    """Return the default ``truncation_side`` for ``runner_type``."""
    if runner_type == "generate" or runner_type == "draft":
        return "left"
    if runner_type == "pooling":
        return "right"
    assert_never(runner_type)


def _detect_auto_tokenizer_mode(
    tokenizer_name: str | Path,
    revision: str | None,
) -> str:
    """Choose a concrete tokenizer mode to use in place of ``"auto"``.

    Checked in order: official Mistral tokenizer files, the Grok2 tiktoken
    tokenizer file, a Qwen-VL model name; falls back to ``"hf"``.
    """
    name = str(tokenizer_name)

    # Try to use official Mistral tokenizer if possible.
    if is_mistral_model_repo(
        model_name_or_path=name, revision=revision
    ) and any_pattern_in_repo_files(
        model_name_or_path=name,
        allow_patterns=["tekken.json", "tokenizer.model.v*"],
        revision=revision,
    ):
        return "mistral"

    # Try to use Grok2 tiktoken tokenizer if possible.
    if any_pattern_in_repo_files(
        model_name_or_path=name,
        allow_patterns=["tokenizer.tok.json"],
        revision=revision,
    ):
        return "grok2"

    # Model-specific tokenizers.
    if "/Qwen-VL" in name:
        return "qwen_vl"

    # Fallback to HF tokenizer.
    return "hf"


def resolve_tokenizer_args(
    tokenizer_name: str | Path,
    *args,
    runner_type: "RunnerType" = "generate",
    tokenizer_mode: str = "auto",
    **kwargs,
) -> tuple[str, str | Path, tuple, dict]:
    """Normalize the arguments used to construct a tokenizer.

    Steps, in order:

    1. If ``VLLM_USE_MODELSCOPE`` is set, download the tokenizer files.
    2. Split a GGUF path into its model folder / repo plus a ``gguf_file``
       kwarg.
    3. Default ``truncation_side`` from ``runner_type`` if not given.
    4. Map ``tokenizer_mode="slow"`` to ``"hf"`` with ``use_fast=False``.
    5. Resolve ``tokenizer_mode="auto"`` to a concrete mode.

    ``revision`` and ``download_dir`` are read from ``kwargs`` and left in it.

    Returns:
        ``(tokenizer_mode, tokenizer_name, args, kwargs)``.

    Raises:
        ValueError: if ``tokenizer_mode="slow"`` is combined with
            ``use_fast=True``.
    """
    revision: str | None = kwargs.get("revision")
    download_dir: str | None = kwargs.get("download_dir")

    tokenizer_name = _maybe_download_from_modelscope(
        tokenizer_name, revision, download_dir
    )

    tokenizer_name, gguf_kwargs = _split_gguf_tokenizer_name(tokenizer_name, revision)
    kwargs.update(gguf_kwargs)

    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = _default_truncation_side(runner_type)

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")

        tokenizer_mode = "hf"
        kwargs["use_fast"] = False

    if tokenizer_mode == "auto":
        tokenizer_mode = _detect_auto_tokenizer_mode(tokenizer_name, revision)

    return tokenizer_mode, tokenizer_name, args, kwargs


# Memoized ``resolve_tokenizer_args``. All arguments must be hashable (they
# form the lru_cache key), and calls that hit the same cache entry receive the
# very same ``(mode, name, args, kwargs)`` tuple - treat the returned kwargs
# dict as read-only.
cached_resolve_tokenizer_args = lru_cache(resolve_tokenizer_args)


def tokenizer_args_from_config(config: "ModelConfig", **kwargs):
    """Resolve tokenizer arguments for ``config`` (memoized).

    Uses the config's tokenizer name, runner type, tokenizer mode, tokenizer
    revision and ``trust_remote_code`` flag, plus any extra ``kwargs``.

    Returns:
        ``(tokenizer_mode, tokenizer_name, args, kwargs)`` as produced by
        ``resolve_tokenizer_args``. The returned kwargs dict is a fresh copy
        and may be modified by the caller.
    """
    tokenizer_mode, tokenizer_name, args, resolved_kwargs = (
        cached_resolve_tokenizer_args(
            config.tokenizer,
            runner_type=config.runner_type,
            tokenizer_mode=config.tokenizer_mode,
            revision=config.tokenizer_revision,
            trust_remote_code=config.trust_remote_code,
            **kwargs,
        )
    )
    # The result comes from an lru_cache, so the kwargs dict is shared with
    # every caller hitting the same entry; copy it so that mutations made by
    # our caller cannot corrupt the cache.
    return tokenizer_mode, tokenizer_name, args, dict(resolved_kwargs)


# Tokenizer type returned by ``get_tokenizer``; bounded by ``TokenizerLike``
# and defaulting to it when the caller does not pin a ``tokenizer_cls``.
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)


def get_tokenizer(
    tokenizer_name: str | Path,
    *args,
    tokenizer_cls: type[_T] = TokenizerLike,  # type: ignore[assignment]
    trust_remote_code: bool = False,
    revision: str | None = None,
    download_dir: str | None = None,
    **kwargs,
) -> _T:
    """Load a tokenizer for ``tokenizer_name`` via HuggingFace or ModelScope.

    The arguments are first normalized by ``cached_resolve_tokenizer_args``
    (ModelScope download, GGUF splitting, tokenizer-mode selection). When
    ``tokenizer_cls`` is left at ``TokenizerLike``, the class registered in
    ``TokenizerRegistry`` for the resolved mode is used; otherwise the given
    class is used directly. A warning is logged if the loaded tokenizer is not
    a fast tokenizer.
    """
    resolved = cached_resolve_tokenizer_args(
        tokenizer_name,
        *args,
        trust_remote_code=trust_remote_code,
        revision=revision,
        download_dir=download_dir,
        **kwargs,
    )
    mode, resolved_name, resolved_args, resolved_kwargs = resolved

    use_registry = tokenizer_cls == TokenizerLike
    chosen_cls = (
        TokenizerRegistry.load_tokenizer_cls(mode) if use_registry else tokenizer_cls
    )

    tokenizer = chosen_cls.from_pretrained(
        resolved_name, *resolved_args, **resolved_kwargs
    )

    if not tokenizer.is_fast:
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )

    return tokenizer  # type: ignore
230
231
232
233
234


# Memoized ``get_tokenizer``: repeated calls with identical (hashable)
# arguments return the same tokenizer object instead of loading it again.
cached_get_tokenizer = lru_cache(get_tokenizer)


235
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
    """Return the memoized tokenizer for ``model_config``, or ``None``.

    ``None`` is returned when ``model_config.skip_tokenizer_init`` is set.
    Otherwise the tokenizer comes from ``cached_get_tokenizer`` using the
    config's tokenizer name, runner type, tokenizer mode, tokenizer revision
    and ``trust_remote_code`` flag, plus any extra ``kwargs``.
    """
    if not model_config.skip_tokenizer_init:
        return cached_get_tokenizer(
            model_config.tokenizer,
            runner_type=model_config.runner_type,
            tokenizer_mode=model_config.tokenizer_mode,
            revision=model_config.tokenizer_revision,
            trust_remote_code=model_config.trust_remote_code,
            **kwargs,
        )

    return None