"tests/kernels/attention/test_cache.py" did not exist on "74d8d77626763bf7c4a2dd227231c69bb4638e29"
tokenizer.py 3.39 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import warnings
5
from functools import lru_cache
6
from typing import TYPE_CHECKING, Any
7

8
from typing_extensions import assert_never
9

Woosuk Kwon's avatar
Woosuk Kwon committed
10
from vllm.logger import init_logger
11
from vllm.tokenizers import TokenizerLike, get_tokenizer
12

13
14
if TYPE_CHECKING:
    from vllm.config import ModelConfig
15

16

17
18
# Logger for this module, obtained from vLLM's `init_logger` keyed on `__name__`.
logger = init_logger(__name__)

19
20
21
22
23
24
25
26
27
28
29
30

def __getattr__(name: str):
    """
    Module-level attribute hook (PEP 562) for names that moved elsewhere.

    Resolves the deprecated aliases `AnyTokenizer` (now
    `vllm.tokenizers.TokenizerLike`) and `get_cached_tokenizer` (now
    `vllm.tokenizers.hf.get_cached_tokenizer`). Each resolution emits a
    `DeprecationWarning` attributed to the caller (`stacklevel=2`).
    Any other name raises `AttributeError`, as normal module lookup would.
    """
    if name == "AnyTokenizer":
        target = TokenizerLike
        notice = (
            "`vllm.transformers_utils.tokenizer.AnyTokenizer` has been moved to "
            "`vllm.tokenizers.TokenizerLike`. "
            "The old name will be removed in v0.13."
        )
    elif name == "get_cached_tokenizer":
        # Imported lazily so `vllm.tokenizers.hf` is only loaded when the
        # deprecated alias is actually requested.
        from vllm.tokenizers.hf import get_cached_tokenizer as target

        notice = (
            "`vllm.transformers_utils.tokenizer.get_cached_tokenizer` "
            "has been moved to `vllm.tokenizers.hf.get_cached_tokenizer`. "
            "The old name will be removed in v0.13."
        )
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    return target
45

46

47
def decode_tokens(
48
    tokenizer: TokenizerLike,
49
50
    token_ids: list[int],
    *,
51
    skip_special_tokens: bool | None = None,
52
53
54
) -> str:
    """
    Backend-agnostic equivalent of HF's
55
    `tokenizer.decode(token_ids, ...)`.
56

57
    `skip_special_tokens=None` means to use the backend's default
58
    settings.
59
    """
60
61
    kw_args: dict[str, Any] = {}

62
    if skip_special_tokens is not None:
63
        kw_args["skip_special_tokens"] = skip_special_tokens
64

65
    return tokenizer.decode(token_ids, **kw_args)
66
67


68
def encode_tokens(
69
    tokenizer: TokenizerLike,
70
71
    text: str,
    *,
72
73
74
    truncation: bool | None = None,
    max_length: int | None = None,
    add_special_tokens: bool | None = None,
75
76
77
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
78
    `tokenizer.encode(text, ...)`.
79

80
    `add_special_tokens=None` means to use the backend's default
81
    settings.
82
    """
83
84
85
86
87
88
89
90

    kw_args: dict[str, Any] = {}
    if max_length is not None:
        kw_args["max_length"] = max_length

    if truncation is not None:
        kw_args["truncation"] = truncation

91
    if add_special_tokens is not None:
92
        kw_args["add_special_tokens"] = add_special_tokens
93

94
    return tokenizer.encode(text, **kw_args)
95
96


97
98
99
100
# Memoized `get_tokenizer`: calls with identical (hashable) arguments reuse the
# previously built tokenizer while it remains among the 128 most recent entries.
cached_get_tokenizer = lru_cache(maxsize=128)(get_tokenizer)


def cached_tokenizer_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """
    Return the tokenizer described by `model_config` via `cached_get_tokenizer`.

    The tokenizer name, mode, revision and `trust_remote_code` flag are read
    from `model_config`. Any extra keyword arguments are forwarded unchanged
    after them.

    The call goes through the `lru_cache`-wrapped `cached_get_tokenizer`, so
    repeated calls with equal (hashable) arguments can reuse a prior result.
    """
    tokenizer_name = model_config.tokenizer
    # Keyword order is kept stable: `lru_cache` keys depend on it.
    fixed_options = {
        "tokenizer_mode": model_config.tokenizer_mode,
        "revision": model_config.tokenizer_revision,
        "trust_remote_code": model_config.trust_remote_code,
    }
    return cached_get_tokenizer(tokenizer_name, **fixed_options, **kwargs)


113
def init_tokenizer_from_configs(model_config: "ModelConfig"):
114
115
116
117
118
119
120
    runner_type = model_config.runner_type
    if runner_type == "generate" or runner_type == "draft":
        truncation_side = "left"
    elif runner_type == "pooling":
        truncation_side = "right"
    else:
        assert_never(runner_type)
121

122
123
124
125
126
127
128
    return get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,
        truncation_side=truncation_side,
    )