selector.py 8.24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import inspect
5
import os
6
from collections.abc import Generator
7
from contextlib import contextmanager
8
from functools import cache
9
from typing import cast, get_args
10
11
12

import torch

13
import vllm.envs as envs
14
from vllm.attention.backends.abstract import AttentionBackend
15
16
17
18
19
from vllm.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    AttentionBackendEnum,
    MambaAttentionBackendEnum,
)
20
from vllm.config.cache import CacheDType
21
from vllm.logger import init_logger
22
from vllm.utils.import_utils import resolve_obj_by_qualname
23
24
25
26

# Module-level logger named after this module (via init_logger(__name__)).
logger = init_logger(__name__)


def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
    """
    Get the backend override specified by the vLLM attention
    backend environment variable, if one is specified.

    The deprecated ``_VLLM_V1`` suffix is accepted (with a one-time
    warning) and stripped, matching how ``_cached_get_attn_backend``
    treats the same variable.

    Returns:

    * AttentionBackendEnum value if an override is specified
    * None otherwise

    Raises:

    * ValueError if the variable names the removed XFORMERS backend
    * KeyError if the variable names an unknown backend
    """
    backend_name = os.environ.get("VLLM_ATTENTION_BACKEND")
    if backend_name is None:
        return None

    # Accept the legacy suffix instead of failing the enum lookup below, so
    # this helper agrees with the main selector on what values are valid.
    if backend_name.endswith("_VLLM_V1"):
        logger.warning_once(
            "The suffix '_VLLM_V1' in the environment variable "
            "VLLM_ATTENTION_BACKEND is no longer necessary as "
            "V0 backends have been deprecated. "
            "Please remove this suffix from your "
            "environment variable setting."
        )
        backend_name = backend_name.removesuffix("_VLLM_V1")

    # Checked after suffix stripping so "XFORMERS_VLLM_V1" also gets the
    # explanatory message rather than a bare KeyError.
    if backend_name == "XFORMERS":
        raise ValueError(
            "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
            "details). Please select a supported attention backend."
        )
    return AttentionBackendEnum[backend_name]


# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
#
# Written by global_force_attn_backend(); read by
# get_global_forced_attn_backend().
forced_attn_backend: AttentionBackendEnum | None = None


def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
    """
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend

    # _cached_get_attn_backend reads this override internally but it is not
    # part of that function's cache key; drop memoized results so the new
    # setting takes effect on the next lookup instead of being shadowed by a
    # previously cached backend.
    _cached_get_attn_backend.cache_clear()


def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
    """
    Get the currently-forced choice of attention backend,
    or None if auto-selection is currently enabled.

    Returns:

    * the value of the module-level ``forced_attn_backend`` override
    """
    return forced_attn_backend


def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    Validates ``kv_cache_dtype`` against the members of ``CacheDType`` (as
    returned by ``typing.get_args``) and then delegates to the memoized
    ``_cached_get_attn_backend``. All arguments are forwarded by keyword.

    Args:
        head_size: attention head size.
        dtype: model dtype.
        kv_cache_dtype: KV cache dtype name, or None.
        block_size: KV cache block size, or None.
        use_mla: whether MLA attention is used.
        has_sink: whether the layer has attention sinks.
        use_sparse: whether sparse attention is used.
        attn_type: optional attention type string, passed through as-is.

    Returns:
        The resolved attention backend class.

    Raises:
        AssertionError: if ``kv_cache_dtype`` is not a valid ``CacheDType``
            (note: skipped when Python runs with ``-O``).
    """
    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in valid_cache_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {valid_cache_dtypes}"
        )

    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        # Validated above, so narrowing the type for the cached helper is safe.
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        attn_type=attn_type,
    )


def _get_requested_attn_backend() -> AttentionBackendEnum | None:
    """Return the backend explicitly requested by the user, or None.

    A backend forced via ``global_force_attn_backend`` takes precedence over
    the ``VLLM_ATTENTION_BACKEND`` environment variable (read through
    ``vllm.envs``). A deprecated ``_VLLM_V1`` suffix on the variable is
    stripped with a warning.

    Raises:
        ValueError: if the environment variable names an unknown backend.
    """
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    backend_by_global_setting = get_global_forced_attn_backend()
    if backend_by_global_setting is not None:
        return backend_by_global_setting

    backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND
    if backend_by_env_var is None:
        return None

    if backend_by_env_var.endswith("_VLLM_V1"):
        logger.warning_once(
            "The suffix '_VLLM_V1' in the environment variable "
            "VLLM_ATTENTION_BACKEND is no longer necessary as "
            "V0 backends have been deprecated. "
            "Please remove this suffix from your "
            "environment variable setting."
        )
        backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")

    try:
        return AttentionBackendEnum[backend_by_env_var]
    except KeyError as e:
        raise ValueError(
            f"Invalid attention backend: '{backend_by_env_var}'. Valid "
            f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
        ) from e


@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Resolve and memoize the attention backend class for a configuration.

    Steps:

    1. Determine the user-requested backend (global override first, then the
       ``VLLM_ATTENTION_BACKEND`` environment variable), or None.
    2. Ask ``current_platform.get_attn_backend_cls`` for the qualified name
       of the backend class (None is passed when nothing was requested;
       presumably the platform then applies its own default).
    3. Import that class and, if it declares a required KV cache layout,
       install it via ``set_kv_cache_layout`` (a global side effect).

    NOTE: results are cached on the arguments only. The global override and
    the environment variable are read inside this function but are not part
    of the cache key, so ``_cached_get_attn_backend.cache_clear()`` must be
    called after either changes.

    Raises:
        ValueError: if the environment variable names an unknown backend, or
            the platform returns no backend class.
    """
    selected_backend = _get_requested_attn_backend()

    # get device-specific attn_backend
    from vllm.platforms import current_platform

    # Platform plugins may still declare the legacy positional ``use_v1``
    # parameter (between block_size and use_mla); build the positional
    # argument list once and splice it in only when required.
    platform_args: list[object] = [
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
    ]
    sig = inspect.signature(current_platform.get_attn_backend_cls)
    if "use_v1" in sig.parameters:
        logger.warning_once(
            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
            "remove it from your plugin code."
        )
        platform_args.append(True)  # use_v1
    platform_args.extend([use_mla, has_sink, use_sparse, attn_type])

    attention_cls = current_platform.get_attn_backend_cls(*platform_args)
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    backend = resolve_obj_by_qualname(attention_cls)

    # Adjust kv cache layout if the selected backend requires a specific one
    required_layout = backend.get_required_kv_cache_layout()
    if required_layout is not None:
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout(required_layout)
        logger.info(
            "Using %s KV cache layout for %s backend.",
            required_layout,
            backend.get_name(),
        )

    return backend


def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Return the attention backend class registered for a mamba layer type.

    Public entry point; the lookup itself is performed (and memoized with
    ``functools.cache``) by ``_cached_get_mamba_attn_backend``.
    """
    backend_cls = _cached_get_mamba_attn_backend(mamba_type)
    return backend_cls


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Resolve (and memoize) the backend class for a mamba layer type.

    ``mamba_type`` is first mapped to a backend name through
    ``MAMBA_TYPE_TO_BACKEND_MAP``; that name is then looked up in
    ``MambaAttentionBackendEnum`` and the member's ``get_class()`` result is
    returned.

    Raises:
        ValueError: if ``mamba_type`` has no mapping, or the mapped name is
            not a ``MambaAttentionBackendEnum`` member.
    """
    assert mamba_type and isinstance(mamba_type, str)

    # Two separate lookups: previously a single try/except referenced
    # ``backend_name`` in the error message even when the *first* lookup
    # failed, which raised UnboundLocalError instead of the intended
    # ValueError.
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{mamba_type}'. Valid "
            f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP)}"
        ) from e

    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend: '{backend_name}' (mapped from "
            f"type '{mamba_type}'). Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()


@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
    """
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    Arguments:

    * attn_backend: attention backend to force

    Yields:

    * None; the override is active for the body of the ``with`` block
    """
    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override. The memoized selector does not
    # include the override in its cache key, so drop results computed under
    # the previous setting; otherwise calls inside the block could return a
    # stale backend.
    global_force_attn_backend(attn_backend)
    _cached_get_attn_backend.cache_clear()

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any, and drop any
        # results cached while the temporary override was active.
        global_force_attn_backend(original_value)
        _cached_get_attn_backend.cache_clear()