selector.py 7.83 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from collections.abc import Generator
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from functools import cache
9
from typing import Optional, Union
10
11
12

import torch

13
import vllm.envs as envs
14
from vllm.attention.backends.abstract import AttentionBackend
15
from vllm.attention.backends.registry import _Backend
16
from vllm.logger import init_logger
17
from vllm.platforms import current_platform
18
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
19
20
21
22

# Module-level logger, created from this module's __name__.
logger = init_logger(__name__)


23
24
25
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """
    Look up the `_Backend` enum member named `backend_name`.

    Arguments:

    * backend_name: name of a `_Backend` member; must not be None

    Returns:

    * _Backend: the enum member, if `backend_name` is a valid in-tree name
    * None: if `backend_name` is not the name of any `_Backend` member
            (an invalid in-tree name, or an out-of-tree platform is loaded)
    """
    assert backend_name is not None
    # Guard on membership first so that an unknown name yields None
    # instead of the lookup failure `_Backend[...]` would raise.
    if backend_name not in _Backend.__members__:
        return None
    return _Backend[backend_name]
34
35
36


def get_env_variable_attn_backend() -> Optional[_Backend]:
    """
    Read the backend override from the vLLM attention backend
    environment variable (`STR_BACKEND_ENV_VAR`), if it is set.

    Returns:

    * The result of `backend_name_to_enum` on the variable's value
      if the variable is set
    * None if the variable is not set
    """
    env_value = os.environ.get(STR_BACKEND_ENV_VAR)
    if env_value is None:
        # Variable not present in the environment: no override requested.
        return None
    return backend_name_to_enum(env_value)
48
49
50
51
52
53
54
55


# Global state that allows a particular backend to be forced,
# overriding the logic which auto-selects a backend based on
# system & workload configuration. Auto-selection is the default
# behavior when this variable is None.
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE.
57
58
59
60
# Backend currently forced for attention backend selection; None means no
# backend is forced. Written by global_force_attn_backend().
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    """
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    The forced backend is read inside `_cached_get_attn_backend`, which
    is memoized on its call arguments only. The memo is therefore cleared
    here, so that lookups made after this call observe the new setting
    instead of a result cached under the previous one.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend
    # Drop results computed under the previous override; without this a
    # repeated lookup with the same arguments would return the stale backend.
    _cached_get_attn_backend.cache_clear()


def get_global_forced_attn_backend() -> Optional[_Backend]:
    """
    Return the module-level `forced_attn_backend` value.

    Returns:

    * The currently-forced attention backend, or None if no backend
      is forced (i.e. auto-selection is enabled)
    """
    return forced_attn_backend


83
84
85
86
87
88
89
90
91
92
93
@dataclass(frozen=True)
class _IsSupported:
    can_import: bool
    head_size: bool
    dtype: bool

    def __bool__(self) -> bool:
        return self.can_import and self.head_size and self.dtype


def is_attn_backend_supported(
    attn_backend: Union[str, type[AttentionBackend]],
    head_size: int,
    dtype: torch.dtype,
    *,
    allow_import_error: bool = True,
) -> _IsSupported:
    """
    Check whether an attention backend can be imported and whether it
    supports the given head size and dtype.

    Arguments:

    * attn_backend: backend class, or a qualified name that is resolved
      to one via `resolve_obj_by_qualname`
    * head_size: head size to check against the backend
    * dtype: dtype to check against the backend
    * allow_import_error: if True, an ImportError while resolving a
      backend name is reported as an all-False result; if False, the
      ImportError is re-raised

    Returns:

    * _IsSupported with the outcome of each individual check

    Raises:

    * ImportError: the backend name cannot be imported and
      `allow_import_error` is False
    * NotImplementedError: the backend exposes neither
      `get_supported_head_sizes` nor `validate_head_size`, or does not
      expose `get_supported_dtypes`
    """
    backend_cls = attn_backend
    if isinstance(backend_cls, str):
        try:
            backend_cls = resolve_obj_by_qualname(backend_cls)
        except ImportError:
            if allow_import_error:
                return _IsSupported(can_import=False, head_size=False, dtype=False)
            raise

    assert isinstance(backend_cls, type)

    # TODO: Update the interface once V0 is removed
    head_sizes_getter = getattr(backend_cls, "get_supported_head_sizes", None)
    if head_sizes_getter:
        head_size_ok = head_size in head_sizes_getter()
    else:
        # Fall back to the validator-style interface, which signals an
        # unsupported head size by raising.
        head_size_validator = getattr(backend_cls, "validate_head_size", None)
        if not head_size_validator:
            raise NotImplementedError(
                f"{backend_cls.__name__} does not support head size validation"
            )
        try:
            head_size_validator(head_size)
        except Exception:
            head_size_ok = False
        else:
            head_size_ok = True

    dtypes_getter = getattr(backend_cls, "get_supported_dtypes", None)
    if not dtypes_getter:
        raise NotImplementedError(
            f"{backend_cls.__name__} does not support dtype validation"
        )
    dtype_ok = dtype in dtypes_getter()

    return _IsSupported(can_import=True, head_size=head_size_ok, dtype=dtype_ok)
139
140


141
142
143
144
145
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """
    Selects which attention backend to use and lazily imports it.

    Thin public wrapper around the memoized `_cached_get_attn_backend`:
    every argument is forwarded unchanged, together with the current
    value of `envs.VLLM_USE_V1`.

    Returns:

    * The attention backend class chosen by `_cached_get_attn_backend`
    """
    # Reading envs.* from inside a cached function can return a stale
    # result if the value changes between calls. VLLM_USE_V1 is therefore
    # sampled here, outside the cache, and handed to the private function
    # explicitly so that it becomes part of the cache key.
    use_v1 = envs.VLLM_USE_V1
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        use_v1=use_v1,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
    )


167
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_v1: bool = False,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """
    Memoized implementation behind `get_attn_backend`.

    Determines an optional backend override, asks the current platform
    for the matching backend class name, and resolves that name to a class.

    The override is taken from, in order of precedence:

    1. the globally forced backend (`get_global_forced_attn_backend`)
    2. `envs.VLLM_ATTENTION_BACKEND` (a trailing "_VLLM_V1" is stripped
       with a warning)

    NOTE: results are cached on the call arguments only. The forced
    backend and the environment variable are read inside this function and
    are not part of the cache key, so changing them is not reflected for
    already-cached arguments until `cache_clear()` is called.

    Raises:

    * ValueError: the environment variable names an unknown backend, or
      the platform returns no backend class
    """
    # A globally forced backend OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE, so the variable is only consulted when
    # nothing was forced.
    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
    if selected_backend is None:
        env_backend_name: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if env_backend_name is not None:
            legacy_suffix = "_VLLM_V1"
            if env_backend_name.endswith(legacy_suffix):
                logger.warning(
                    "The suffix '_VLLM_V1' in the environment variable "
                    "%s is no longer necessary as V0 backends have been "
                    "deprecated. Please remove this suffix from your "
                    "environment variable setting.",
                    STR_BACKEND_ENV_VAR,
                )
                env_backend_name = env_backend_name.removesuffix(legacy_suffix)

            selected_backend = backend_name_to_enum(env_backend_name)
            if selected_backend is None:
                raise ValueError(
                    f"Invalid attention backend: '{env_backend_name}'. "
                    f"Valid backends are: {list(_Backend.__members__)}"
                )

    # Ask the platform for the device-specific backend; the result is
    # passed to resolve_obj_by_qualname below to obtain the class.
    backend_qualname = current_platform.get_attn_backend_cls(
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink,
        use_sparse,
    )
    if not backend_qualname:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )
    return resolve_obj_by_qualname(backend_qualname)
224
225
226
227


@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: _Backend,
) -> Generator[None, None, None]:
    """
    Temporarily force a vLLM attention backend for the duration of a
    `with` block.

    On entry the current global override is remembered and replaced by
    `attn_backend`; on exit the remembered override is reinstated, also
    when the enclosed block raises.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    """
    # Remember whatever override (possibly None) was active before.
    previous_backend = get_global_forced_attn_backend()
    global_force_attn_backend(attn_backend)
    try:
        # Hand control to the body of the `with` statement.
        yield
    finally:
        # Put the earlier override back regardless of how the body exited.
        global_force_attn_backend(previous_backend)