selector.py 8.27 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import inspect
5
import os
6
from collections.abc import Generator
7
from contextlib import contextmanager
8
from functools import cache
9
from typing import cast, get_args
10
11
12

import torch

13
import vllm.envs as envs
14
from vllm.attention.backends.abstract import AttentionBackend
15
16
17
18
19
from vllm.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    AttentionBackendEnum,
    MambaAttentionBackendEnum,
)
20
from vllm.config.cache import CacheDType
21
from vllm.logger import init_logger
22
23
from vllm.utils import STR_BACKEND_ENV_VAR
from vllm.utils.import_utils import resolve_obj_by_qualname
24
25
26
27

# Module-level logger, created via init_logger with this module's name.
logger = init_logger(__name__)


28
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
    """
    Read the attention backend override from the process environment.

    The environment variable consulted is the one named by
    `STR_BACKEND_ENV_VAR`.

    Returns:

    * The AttentionBackendEnum member named by the variable, if it is set
    * None if the variable is not set

    Raises:

    * ValueError if the removed 'XFORMERS' backend is requested
    * KeyError if the value does not name an AttentionBackendEnum member
    """
    requested = os.environ.get(STR_BACKEND_ENV_VAR)
    if requested is not None:
        # XFORMERS is rejected explicitly so users get an actionable message.
        if requested == "XFORMERS":
            raise ValueError(
                "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
                "details). Please select a supported attention backend."
            )
        return AttentionBackendEnum[requested]
    return None
47
48
49
50
51
52
53
54


# Process-wide override for attention backend selection.
#
# While this is None (the default), the backend is auto-selected
# based on system & workload configuration. Setting it to an
# AttentionBackendEnum member forces that backend instead.
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: AttentionBackendEnum | None = None
57
58


59
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
    """
    Set the module-wide forced attention backend.

    The value is stored in the module-level `forced_attn_backend`
    variable. Passing `None` removes the override, which re-enables
    automatic backend selection.

    Note that this function does not clear the memoization cache of
    `_cached_get_attn_backend`; selections that were already cached
    are not affected by a later call.

    Arguments:

    * attn_backend: backend to force (None to revert to auto)
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend


74
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
    """
    Return the module-wide forced attention backend.

    Returns:

    * The AttentionBackendEnum member currently forced, or
    * None when no override is set (auto-selection is enabled)
    """
    return forced_attn_backend


82
83
84
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    Validates `kv_cache_dtype` against the values allowed by `CacheDType`
    (when it is not None) and then delegates to the memoized
    `_cached_get_attn_backend` with the same arguments.
    """
    if kv_cache_dtype is not None:
        allowed_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in allowed_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {allowed_dtypes}"
        )

    # The cast only informs the type checker; the value is passed unchanged.
    typed_cache_dtype = cast(CacheDType | None, kv_cache_dtype)

    # Keyword names and order are kept stable: functools.cache builds its
    # key from the keyword arguments in the order they are passed.
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=typed_cache_dtype,
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        attn_type=attn_type,
    )


113
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Resolve the attention backend class for one configuration.

    The result is memoized with `functools.cache`, so the body (including
    the KV cache layout adjustment at the end) only runs the first time a
    given argument combination is seen, until `cache_clear()` is called.

    Backend preference, highest priority first:

    1. the globally forced backend (`get_global_forced_attn_backend()`),
    2. the `envs.VLLM_ATTENTION_BACKEND` setting,
    3. None, leaving the choice to the current platform.

    Raises:
        ValueError: if the environment variable names an unknown backend,
            or if the platform returns no backend class.
    """
    # A previously forced backend wins.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    selected_backend: AttentionBackendEnum | None = get_global_forced_attn_backend()

    if selected_backend is None:
        # Nothing forced: fall back to the environment variable, if set.
        env_choice: str | None = envs.VLLM_ATTENTION_BACKEND
        if env_choice is not None:
            if env_choice.endswith("_VLLM_V1"):
                logger.warning(
                    "The suffix '_VLLM_V1' in the environment variable "
                    "%s is no longer necessary as V0 backends have been "
                    "deprecated. Please remove this suffix from your "
                    "environment variable setting.",
                    STR_BACKEND_ENV_VAR,
                )
                env_choice = env_choice.removesuffix("_VLLM_V1")
            try:
                selected_backend = AttentionBackendEnum[env_choice]
            except KeyError as err:
                raise ValueError(
                    f"Invalid attention backend: '{env_choice}'. Valid "
                    f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
                ) from err

    # get device-specific attn_backend
    from vllm.platforms import current_platform

    # Positional arguments for the platform hook. Platform implementations
    # that still declare the deprecated `use_v1` parameter receive an extra
    # `True` right after `block_size`.
    platform_args = [selected_backend, head_size, dtype, kv_cache_dtype, block_size]
    hook_params = inspect.signature(current_platform.get_attn_backend_cls).parameters
    if "use_v1" in hook_params:
        logger.warning_once(
            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
            "remove it from your plugin code."
        )
        platform_args.append(True)  # use_v1
    platform_args.extend((use_mla, has_sink, use_sparse, attn_type))

    attention_cls = current_platform.get_attn_backend_cls(*platform_args)
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    # The platform returns a qualified name; resolve it to the backend class.
    backend = resolve_obj_by_qualname(attention_cls)

    # Adjust kv cache layout if the selected backend requires a specific one
    required_layout = backend.get_required_kv_cache_layout()
    if required_layout is None:
        return backend

    from vllm.v1.attention.backends.utils import set_kv_cache_layout

    set_kv_cache_layout(required_layout)
    logger.info(
        "Using %s KV cache layout for %s backend.",
        required_layout,
        backend.get_name(),
    )
    return backend
209
210


211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Select which mamba attention backend to use and lazily import it.

    Thin public wrapper around the memoized
    `_cached_get_mamba_attn_backend`; `mamba_type` is forwarded unchanged.
    """
    backend_cls = _cached_get_mamba_attn_backend(mamba_type)
    return backend_cls


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Resolve the mamba attention backend class for `mamba_type` (memoized).

    The lookup happens in two steps: `mamba_type` is mapped to a backend
    name through `MAMBA_TYPE_TO_BACKEND_MAP`, and that name is then looked
    up in `MambaAttentionBackendEnum`, whose `get_class()` provides the
    backend class.

    Arguments:
        mamba_type: non-empty string key of `MAMBA_TYPE_TO_BACKEND_MAP`.

    Returns:
        The backend class returned by the enum member's `get_class()`.

    Raises:
        ValueError: if `mamba_type` is not a known mamba type, or if the
            mapped backend name is not a `MambaAttentionBackendEnum` member.
    """
    assert mamba_type and isinstance(mamba_type, str)

    # The two lookups are handled separately so that a failure of the first
    # one does not reference `backend_name` before it has been assigned
    # (which previously surfaced as UnboundLocalError instead of ValueError).
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{mamba_type}'. Valid "
            f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP.keys())}"
        ) from e

    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()


238
239
@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
    """
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    The memoization cache of `_cached_get_attn_backend` is cleared both
    when the override is installed and when it is reverted, so that
    backend lookups never return a selection that was cached under a
    different override.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    """

    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)

    # Drop selections memoized before the override was installed; otherwise
    # a lookup inside the block with previously seen arguments would return
    # the stale cached backend and ignore the forced one.
    _cached_get_attn_backend.cache_clear()

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any, and drop
        # selections memoized while the override was active.
        global_force_attn_backend(original_value)
        _cached_get_attn_backend.cache_clear()