tokenizer.py 7.92 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import importlib.util
5
import os
6
import warnings
7
from functools import lru_cache
8
from pathlib import Path
9
from typing import TYPE_CHECKING, Any
10

11
import huggingface_hub
12
from typing_extensions import assert_never
13

14
from vllm import envs
Woosuk Kwon's avatar
Woosuk Kwon committed
15
from vllm.logger import init_logger
16
17
18
19
20
21
from vllm.tokenizers import (
    HfTokenizer,
    MistralTokenizer,
    TokenizerLike,
    TokenizerRegistry,
)
22
23
24
25

from .gguf_utils import get_gguf_file_path_from_hf
from .repo_utils import list_filtered_repo_files
from .utils import check_gguf_file, is_gguf, is_remote_gguf, split_remote_gguf
26

27
28
if TYPE_CHECKING:
    from vllm.config import ModelConfig
29

30

31
32
# Module-level logger, named after this module.
logger = init_logger(__name__)

33
34
35
36
37
38
39
40
41
42
43
44

def __getattr__(name: str):
    """Serve deprecated aliases for names that moved out of this module.

    Python calls this only when normal module attribute lookup fails
    (PEP 562). Each known alias emits a `DeprecationWarning` (with
    `stacklevel=2`, so it points at the accessing code) and then returns the
    object from its new location.

    Raises:
        AttributeError: if `name` is not one of the deprecated aliases.
    """
    if name == "AnyTokenizer":
        relocated = TokenizerLike
        notice = (
            "`vllm.transformers_utils.tokenizer.AnyTokenizer` has been moved to "
            "`vllm.tokenizers.TokenizerLike`. "
            "The old name will be removed in v0.13."
        )
    elif name == "get_cached_tokenizer":
        # Imported here, only when the deprecated alias is actually requested.
        from vllm.tokenizers.hf import get_cached_tokenizer as relocated

        notice = (
            "`vllm.transformers_utils.tokenizer.get_cached_tokenizer` "
            "has been moved to `vllm.tokenizers.hf.get_cached_tokenizer`. "
            "The old name will be removed in v0.13."
        )
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    return relocated
59

60

61
def decode_tokens(
    tokenizer: TokenizerLike,
    token_ids: list[int],
    *,
    skip_special_tokens: bool | None = None,
) -> str:
    """
    Decode `token_ids` with `tokenizer`, independent of the backend.

    Mirrors HF's `tokenizer.decode(token_ids, ...)`. Passing
    `skip_special_tokens=None` leaves the option unset, so the backend's
    own default applies.
    """
    if skip_special_tokens is None:
        # Do not forward the option at all; let the backend decide.
        return tokenizer.decode(token_ids)

    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
80
81


82
def encode_tokens(
    tokenizer: TokenizerLike,
    text: str,
    *,
    truncation: bool | None = None,
    max_length: int | None = None,
    add_special_tokens: bool | None = None,
) -> list[int]:
    """
    Encode `text` with `tokenizer`, independent of the backend.

    Mirrors HF's `tokenizer.encode(text, ...)`. Any of `truncation`,
    `max_length` and `add_special_tokens` left as `None` is not forwarded,
    so the backend's own default for that option applies.
    """
    # Only options the caller explicitly set are passed through.
    requested = (
        ("max_length", max_length),
        ("truncation", truncation),
        ("add_special_tokens", add_special_tokens),
    )
    options: dict[str, Any] = {
        option: value for option, value in requested if value is not None
    }

    return tokenizer.encode(text, **options)
109
110


111
def get_tokenizer(
112
    tokenizer_name: str | Path,
113
    *args,
114
    tokenizer_mode: str = "auto",
115
    trust_remote_code: bool = False,
116
117
    revision: str | None = None,
    download_dir: str | None = None,
118
    **kwargs,
119
) -> TokenizerLike:
120
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
121
    if envs.VLLM_USE_MODELSCOPE:
122
123
124
125
126
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

127
128
129
        # avoid circuit import
        from vllm.model_executor.model_loader.weight_utils import get_lock

130
131
        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
132
133
134
135
136
137
138
139
140
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
141
142
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
                )
143
                tokenizer_name = tokenizer_path
144

145
146
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
147
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
148
149
        kwargs["use_fast"] = False

150
151
152
    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = "left"

153
    # Separate model folder from file path for GGUF models
154
155
156
157
158
    if is_gguf(tokenizer_name):
        if check_gguf_file(tokenizer_name):
            kwargs["gguf_file"] = Path(tokenizer_name).name
            tokenizer_name = Path(tokenizer_name).parent
        elif is_remote_gguf(tokenizer_name):
159
160
161
162
163
164
165
166
            tokenizer_name, quant_type = split_remote_gguf(tokenizer_name)
            # Get the HuggingFace Hub path for the GGUF file
            gguf_file = get_gguf_file_path_from_hf(
                tokenizer_name,
                quant_type,
                revision=revision,
            )
            kwargs["gguf_file"] = gguf_file
167

168
169
170
171
172
173
174
175
176
    # if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
    # first to use official Mistral tokenizer if possible.
    mistral_common_installed = importlib.util.find_spec("mistral_common") is not None
    if tokenizer_mode == "auto" and mistral_common_installed:
        allow_patterns = ["tekken.json", "tokenizer.model.v*"]
        files_list = list_filtered_repo_files(
            model_name_or_path=str(tokenizer_name),
            allow_patterns=allow_patterns,
            revision=revision,
177
        )
178
179
        if len(files_list) > 0:
            tokenizer_mode = "mistral"
180

181
    tokenizer: TokenizerLike
182
    if tokenizer_mode == "mistral":
183
        logger.debug_once(f"Loading MistralTokenizer from {tokenizer_name}")
184
        tokenizer = MistralTokenizer.from_pretrained(
185
186
187
188
189
190
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            download_dir=download_dir,
            **kwargs,
191
        )
192
    elif tokenizer_mode == "custom":
193
        logger.debug_once(f"Loading CustomTokenizer from {tokenizer_name}")
194
195
196
        tokenizer = TokenizerRegistry.get_tokenizer(
            str(tokenizer_name),
            *args,
197
            trust_remote_code=trust_remote_code,
198
199
200
201
            revision=revision,
            download_dir=download_dir,
            **kwargs,
        )
202
    else:
203
204
205
206
207
208
209
210
211
212
213
214
215
216
        logger.debug_once(f"Loading HfTokenizer from {tokenizer_name}")
        tokenizer = HfTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            download_dir=download_dir,
            **kwargs,
        )

    if not tokenizer.is_fast:
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
217
        )
218

219
    return tokenizer
220
221


222
223
224
225
# Memoized `get_tokenizer`: calls with identical (hashable) arguments reuse the
# same tokenizer instance. `maxsize=128` is `lru_cache`'s default bound.
cached_get_tokenizer = lru_cache(maxsize=128)(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Return a tokenizer for `model_config` via `cached_get_tokenizer`.

    The tokenizer name, mode, revision and `trust_remote_code` come from
    `model_config`. Extra `kwargs` are forwarded unchanged, and repeated
    calls with the same arguments hit the cache.
    """
    # Keep this keyword order stable: `lru_cache` keys on keyword order.
    from_config = dict(
        tokenizer_mode=model_config.tokenizer_mode,
        revision=model_config.tokenizer_revision,
        trust_remote_code=model_config.trust_remote_code,
    )
    return cached_get_tokenizer(model_config.tokenizer, **from_config, **kwargs)


238
def init_tokenizer_from_configs(model_config: "ModelConfig"):
239
240
241
242
243
244
245
    runner_type = model_config.runner_type
    if runner_type == "generate" or runner_type == "draft":
        truncation_side = "left"
    elif runner_type == "pooling":
        truncation_side = "right"
    else:
        assert_never(runner_type)
246

247
248
249
250
251
252
253
    return get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,
        truncation_side=truncation_side,
    )