# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import inspect
5
import os
6
from collections.abc import Generator
7
from contextlib import contextmanager
8
from functools import cache
9
from typing import cast, get_args
10
11
12

import torch

13
import vllm.envs as envs
14
from vllm.attention.backends.abstract import AttentionBackend
15
16
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.cache import CacheDType
17
from vllm.logger import init_logger
18
19
from vllm.utils import STR_BACKEND_ENV_VAR
from vllm.utils.import_utils import resolve_obj_by_qualname
20
21
22
23

logger = init_logger(__name__)


24
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
    """
    Get the backend override specified by the vLLM attention
    backend environment variable, if one is specified.

    A legacy ``_VLLM_V1`` suffix on the value is stripped (with a warning),
    matching how the backend selector interprets the same variable, so both
    code paths agree on which backend the variable names.

    Returns:

    * AttentionBackendEnum value if an override is specified
    * None otherwise

    Raises:

    * KeyError if the (suffix-stripped) value is not an
      AttentionBackendEnum member name
    """
    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
    if backend_name is None:
        return None

    # The selector accepts the deprecated "_VLLM_V1" suffix; without this,
    # querying the override here would fail for a value the selector honors.
    if backend_name.endswith("_VLLM_V1"):
        logger.warning(
            "The suffix '_VLLM_V1' in the environment variable "
            "%s is no longer necessary as V0 backends have been "
            "deprecated. Please remove this suffix from your "
            "environment variable setting.",
            STR_BACKEND_ENV_VAR,
        )
        backend_name = backend_name.removesuffix("_VLLM_V1")

    return AttentionBackendEnum[backend_name]


# Process-wide override for attention backend selection.
#
# When this holds a backend (set through global_force_attn_backend), that
# backend is used instead of the automatic choice normally derived from the
# system and workload configuration. When it is None, automatic selection
# applies.
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: AttentionBackendEnum | None = None


48
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
    """
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    The memoized backend lookup (`_cached_get_attn_backend`) is cleared so
    that later lookups honor the new setting instead of returning a backend
    that was selected under the previous override.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend
    # _cached_get_attn_backend is keyed only on its arguments and reads the
    # forced backend inside its body, so cached entries become stale the
    # moment the override changes.
    _cached_get_attn_backend.cache_clear()


63
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
    """
    Return the backend most recently forced via `global_force_attn_backend`,
    or None when no override is active and backends are auto-selected.
    """
    current_override = forced_attn_backend
    return current_override


71
72
73
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    ``kv_cache_dtype`` (when not None) is checked against the members of
    ``CacheDType`` (as reported by ``typing.get_args``) before delegating to
    the memoized ``_cached_get_attn_backend``.

    Raises:
        ValueError: if ``kv_cache_dtype`` is not a valid ``CacheDType``
            value. An explicit check is used rather than ``assert`` so the
            validation still runs when Python is started with ``-O``.
    """
    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        if kv_cache_dtype not in valid_cache_dtypes:
            raise ValueError(
                f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
                f"Valid values are: {valid_cache_dtypes}"
            )

    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        # Validated above; narrow the static type for the cached helper.
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
    )


100
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """
    Resolve and import the attention backend class for a configuration.

    Memoized with ``functools.cache``: the cache key is the argument tuple
    only. The forced backend and ``envs.VLLM_ATTENTION_BACKEND`` are read
    inside the body and are not part of the key.

    The backend requested from the platform is chosen in this order:
    1. a backend forced via ``global_force_attn_backend``;
    2. otherwise ``envs.VLLM_ATTENTION_BACKEND`` (a legacy ``_VLLM_V1``
       suffix is stripped with a warning);
    3. otherwise ``None``, leaving the choice to the platform.

    Raises:
        ValueError: if the environment variable names an unknown backend,
            or if the platform returns no backend class.
    """
    # THE GLOBALLY FORCED BACKEND OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    requested: AttentionBackendEnum | None = get_global_forced_attn_backend()

    if requested is None:
        env_value: str | None = envs.VLLM_ATTENTION_BACKEND
        if env_value is not None:
            if env_value.endswith("_VLLM_V1"):
                logger.warning(
                    "The suffix '_VLLM_V1' in the environment variable "
                    "%s is no longer necessary as V0 backends have been "
                    "deprecated. Please remove this suffix from your "
                    "environment variable setting.",
                    STR_BACKEND_ENV_VAR,
                )
                env_value = env_value.removesuffix("_VLLM_V1")
            try:
                requested = AttentionBackendEnum[env_value]
            except KeyError as exc:
                raise ValueError(
                    f"Invalid attention backend: '{env_value}'. Valid "
                    f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
                ) from exc

    # Imported at call time rather than at module import time
    # (presumably to avoid an import cycle; unverified).
    from vllm.platforms import current_platform

    # Positional arguments for the platform hook. Implementations that still
    # declare the deprecated ``use_v1`` parameter receive True for it,
    # immediately after ``block_size``.
    platform_args: list[object] = [
        requested,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
    ]
    hook_params = inspect.signature(current_platform.get_attn_backend_cls).parameters
    if "use_v1" in hook_params:
        logger.warning_once(
            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
            "remove it from your plugin code."
        )
        platform_args.append(True)  # use_v1
    platform_args.extend((use_mla, has_sink, use_sparse))
    attention_cls = current_platform.get_attn_backend_cls(*platform_args)

    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    backend = resolve_obj_by_qualname(attention_cls)

    # Some backends mandate a specific KV cache layout; apply it through
    # set_kv_cache_layout before handing the backend back.
    layout = backend.get_required_kv_cache_layout()
    if layout is None:
        return backend

    from vllm.v1.attention.backends.utils import set_kv_cache_layout

    set_kv_cache_layout(layout)
    logger.info(
        "Using %s KV cache layout for %s backend.",
        layout,
        backend.get_name(),
    )
    return backend
193
194
195
196


@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
    """
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    The memoized backend lookup is cleared both on entry and on exit.
    Backends cached before entry are not returned inside the block, and
    backends cached inside the block do not leak out after it.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    """
    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)
    # _cached_get_attn_backend is keyed only on its arguments, so results
    # computed before the override would otherwise be served inside the block.
    _cached_get_attn_backend.cache_clear()

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any, and drop
        # results that were selected under the temporary override.
        global_force_attn_backend(original_value)
        _cached_get_attn_backend.cache_clear()