selector.py 8.04 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import inspect
5
import os
6
from collections.abc import Generator
7
from contextlib import contextmanager
8
from functools import cache
9
from typing import cast, get_args
10
11
12

import torch

13
import vllm.envs as envs
14
from vllm.attention.backends.abstract import AttentionBackend
15
16
17
18
19
from vllm.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    AttentionBackendEnum,
    MambaAttentionBackendEnum,
)
20
from vllm.config.cache import CacheDType
21
from vllm.logger import init_logger
22
23
from vllm.utils import STR_BACKEND_ENV_VAR
from vllm.utils.import_utils import resolve_obj_by_qualname
24
25
26
27

# Module logger used for the backend-selection warnings and info messages below.
logger = init_logger(__name__)


28
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
    """
    Get the backend override specified by the vLLM attention
    backend environment variable, if one is specified.

    A trailing deprecated ``_VLLM_V1`` suffix is stripped before the
    lookup, matching how the backend selector interprets the same
    variable.

    Returns:

    * AttentionBackendEnum value if an override is specified
    * None otherwise

    Raises:

    * KeyError if the (suffix-stripped) value is not the name of an
      AttentionBackendEnum member
    """
    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
    if backend_name is None:
        return None
    # The main selector accepts "<NAME>_VLLM_V1" (with a deprecation
    # warning); accept it here too instead of failing the enum lookup.
    return AttentionBackendEnum[backend_name.removesuffix("_VLLM_V1")]
40
41
42
43
44
45
46
47


# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (the default behavior when this variable is None).
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE.
# Currently forced backend (None = automatic selection). Set by
# global_force_attn_backend() and read by get_global_forced_attn_backend().
forced_attn_backend: AttentionBackendEnum | None = None
50
51


52
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
    """
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    The memoized selector ``_cached_get_attn_backend`` is keyed only on
    its arguments, not on this override, so its cache is cleared here.
    Subsequent lookups then honor the new setting instead of returning
    backends resolved under the previous one.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend
    # Drop selections memoized under the previous override.
    _cached_get_attn_backend.cache_clear()


67
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
    """
    Return the attention backend currently forced via
    global_force_attn_backend(), or None when automatic
    selection is in effect.
    """
    current_override = forced_attn_backend
    return current_override


75
76
77
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Select which attention backend to use and lazily import it.

    Validates ``kv_cache_dtype`` against the values of ``CacheDType``,
    then delegates to the memoized ``_cached_get_attn_backend``. All other
    arguments are forwarded unchanged and end up in the current platform's
    ``get_attn_backend_cls``.

    Args:
        head_size: Attention head size, forwarded to the platform selector.
        dtype: Model dtype, forwarded to the platform selector.
        kv_cache_dtype: KV cache dtype name, or None. When not None it must
            be one of ``get_args(CacheDType)``.
        block_size: KV cache block size, or None; forwarded.
        use_mla: Forwarded to the platform selector.
        has_sink: Forwarded to the platform selector.
        use_sparse: Forwarded to the platform selector.
        attn_type: Forwarded to the platform selector.

    Returns:
        The resolved attention backend class.

    Raises:
        ValueError: if ``kv_cache_dtype`` is not a valid ``CacheDType`` value.
    """
    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        # Explicit raise instead of ``assert``: the check must survive
        # ``python -O``, otherwise invalid dtypes reach the cached selector.
        if kv_cache_dtype not in valid_cache_dtypes:
            raise ValueError(
                f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
                f"Valid values are: {valid_cache_dtypes}"
            )

    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        attn_type=attn_type,
    )


106
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Resolve and import the attention backend class (memoized).

    The requested backend is chosen by precedence:

    1. a backend forced via ``global_force_attn_backend`` (THIS OVERRIDES
       THE VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE);
    2. the VLLM_ATTENTION_BACKEND environment variable, with a deprecated
       ``_VLLM_V1`` suffix stripped (and warned about);
    3. ``None``, leaving the choice to the current platform.

    The request is handed to ``current_platform.get_attn_backend_cls``.
    The returned qualified name is imported, and the backend's required
    KV cache layout (if any) is applied globally.

    Raises:
        ValueError: if the environment variable names an unknown backend,
            or the platform returns no backend class.
    """
    selected_backend: AttentionBackendEnum | None = (
        get_global_forced_attn_backend()
    )
    if selected_backend is None:
        env_choice: str | None = envs.VLLM_ATTENTION_BACKEND
        if env_choice is not None:
            if env_choice.endswith("_VLLM_V1"):
                logger.warning(
                    "The suffix '_VLLM_V1' in the environment variable "
                    "%s is no longer necessary as V0 backends have been "
                    "deprecated. Please remove this suffix from your "
                    "environment variable setting.",
                    STR_BACKEND_ENV_VAR,
                )
                env_choice = env_choice.removesuffix("_VLLM_V1")
            try:
                selected_backend = AttentionBackendEnum[env_choice]
            except KeyError as err:
                raise ValueError(
                    f"Invalid attention backend: '{env_choice}'. Valid "
                    f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
                ) from err

    # Function-scope import (presumably to defer platform resolution or to
    # avoid an import cycle -- kept exactly where the lookup needs it).
    from vllm.platforms import current_platform

    # Positional arguments shared by both calling conventions below.
    leading_args = (selected_backend, head_size, dtype, kv_cache_dtype, block_size)
    trailing_args = (use_mla, has_sink, use_sparse, attn_type)

    platform_params = inspect.signature(
        current_platform.get_attn_backend_cls
    ).parameters
    if "use_v1" in platform_params:
        # Legacy plugin signature: it still takes a positional use_v1 flag.
        logger.warning_once(
            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
            "remove it from your plugin code."
        )
        attention_cls = current_platform.get_attn_backend_cls(
            *leading_args, True, *trailing_args
        )
    else:
        attention_cls = current_platform.get_attn_backend_cls(
            *leading_args, *trailing_args
        )

    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    backend = resolve_obj_by_qualname(attention_cls)

    # Backends without a KV cache layout requirement need no extra setup.
    required_layout = backend.get_required_kv_cache_layout()
    if required_layout is None:
        return backend

    from vllm.v1.attention.backends.utils import set_kv_cache_layout

    set_kv_cache_layout(required_layout)
    logger.info(
        "Using %s KV cache layout for %s backend.",
        required_layout,
        backend.get_name(),
    )
    return backend
202
203


204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Return the attention backend class for the given mamba layer type.

    Thin public entry point over ``_cached_get_mamba_attn_backend``, so the
    selection (and the import of the backend class) happens once per
    ``mamba_type``.
    """
    backend_cls = _cached_get_mamba_attn_backend(mamba_type)
    return backend_cls


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Resolve the mamba attention backend class for ``mamba_type`` (memoized).

    The lookup runs in three steps:

    1. ``MAMBA_TYPE_TO_BACKEND_MAP`` maps ``mamba_type`` to a backend name.
    2. That name is looked up as a ``MambaAttentionBackendEnum`` member.
    3. The member's ``get_class()`` result is returned.

    Raises:
        ValueError: if ``mamba_type`` is empty or not a ``str``, if it has no
            entry in ``MAMBA_TYPE_TO_BACKEND_MAP``, or if the mapped name is
            not a ``MambaAttentionBackendEnum`` member.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not mamba_type or not isinstance(mamba_type, str):
        raise ValueError(
            f"Invalid mamba type: {mamba_type!r}. Expected a non-empty string."
        )

    # Two separate lookups: a miss in the type map must not reach an error
    # path that references ``backend_name`` before it is bound, and each
    # failure reports the value that was actually missing.
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{mamba_type}'. Valid "
            f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP.keys())}"
        ) from e

    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()


231
232
@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
    """
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    The memoized ``_cached_get_attn_backend`` is keyed only on its
    arguments, not on the override. Its cache is therefore cleared both
    on entry, so the forced backend takes effect even for argument tuples
    resolved earlier, and on exit, so no selection made under the forced
    backend outlives the block.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    """
    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)

    try:
        # Discard selections memoized before the override was applied.
        _cached_get_attn_backend.cache_clear()
        # Yield control back to the enclosed code block
        yield
    finally:
        # Revert the original global backend override, if any, and discard
        # selections memoized under the forced backend.
        global_force_attn_backend(original_value)
        _cached_get_attn_backend.cache_clear()