selector.py 7.46 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from collections.abc import Generator
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from functools import cache
9
from typing import Optional, Union
10
11
12

import torch

13
import vllm.envs as envs
14
from vllm.attention.backends.abstract import AttentionBackend
15
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
16
from vllm.logger import init_logger
17
from vllm.platforms import current_platform
18
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
19
20
21
22

# Module-level logger named after this module; used for the deprecation
# warning emitted while selecting an attention backend.
logger = init_logger(__name__)


23
def get_env_variable_attn_backend() -> Optional[_Backend]:
    """
    Look up the attention backend override named by the vLLM attention
    backend environment variable (``STR_BACKEND_ENV_VAR``).

    Returns:

    * the value of ``backend_name_to_enum`` for the variable's contents
      when the variable is set (this may itself be None for a name that
      ``backend_name_to_enum`` does not recognize)
    * None when the variable is unset
    """
    env_value = os.environ.get(STR_BACKEND_ENV_VAR)
    if env_value is None:
        # No override requested.
        return None
    return backend_name_to_enum(env_value)
35
36
37
38
39
40
41
42


# Global state allows a particular choice of backend to be forced,
# overriding the logic which auto-selects a backend based on system &
# workload configuration (the default behavior if this variable is None).
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    """
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend
    # `_cached_get_attn_backend` memoizes its result with `@cache`, and the
    # forced backend influences that result. Drop the memoized entries so
    # the next lookup honors the new override instead of returning a backend
    # that was cached under the previous one.
    _cached_get_attn_backend.cache_clear()


def get_global_forced_attn_backend() -> Optional[_Backend]:
    """
    Get the currently-forced choice of attention backend,
    or None if auto-selection is currently enabled.
    """
    return forced_attn_backend


70
71
72
73
74
75
76
77
78
79
80
@dataclass(frozen=True)
class _IsSupported:
    can_import: bool
    head_size: bool
    dtype: bool

    def __bool__(self) -> bool:
        return self.can_import and self.head_size and self.dtype


def is_attn_backend_supported(
    attn_backend: Union[str, type[AttentionBackend]],
    head_size: int,
    dtype: torch.dtype,
    *,
    allow_import_error: bool = True,
) -> _IsSupported:
    """Check whether an attention backend supports a head size and dtype.

    Args:
        attn_backend: the backend class, or its fully-qualified name, which
            is resolved with ``resolve_obj_by_qualname``.
        head_size: the head size to check.
        dtype: the dtype to check.
        allow_import_error: when True, a name that raises ImportError on
            resolution yields an all-False result; when False the
            ImportError propagates.

    Returns:
        An ``_IsSupported`` holding the result of each individual check.

    Raises:
        ImportError: if the name cannot be imported and
            ``allow_import_error`` is False.
        NotImplementedError: if the backend has neither
            ``get_supported_head_sizes`` nor ``validate_head_size``, or has
            no ``get_supported_dtypes``.
    """
    if isinstance(attn_backend, str):
        try:
            attn_backend = resolve_obj_by_qualname(attn_backend)
        except ImportError:
            if allow_import_error:
                return _IsSupported(can_import=False, head_size=False, dtype=False)
            raise

    assert isinstance(attn_backend, type)

    # Keyword arguments are evaluated left to right, so the head-size check
    # (and any NotImplementedError it raises) runs before the dtype check.
    return _IsSupported(
        can_import=True,
        head_size=_check_head_size(attn_backend, head_size),
        dtype=_check_dtype(attn_backend, dtype),
    )


def _check_head_size(backend: type, head_size: int) -> bool:
    """Return whether ``backend`` accepts ``head_size``."""
    # TODO: Update the interface once V0 is removed
    supported_sizes_fn = getattr(backend, "get_supported_head_sizes", None)
    if supported_sizes_fn:
        return head_size in supported_sizes_fn()

    validator = getattr(backend, "validate_head_size", None)
    if validator:
        # Any exception from the validator means "not supported".
        try:
            validator(head_size)
        except Exception:
            return False
        return True

    raise NotImplementedError(
        f"{backend.__name__} does not support head size validation"
    )


def _check_dtype(backend: type, dtype: torch.dtype) -> bool:
    """Return whether ``backend`` lists ``dtype`` among its supported dtypes."""
    supported_dtypes_fn = getattr(backend, "get_supported_dtypes", None)
    if not supported_dtypes_fn:
        raise NotImplementedError(
            f"{backend.__name__} does not support dtype validation"
        )
    return dtype in supported_dtypes_fn()
126
127


128
129
130
131
132
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    All arguments are forwarded to ``current_platform.get_attn_backend_cls``
    together with the backend override (if any).

    Returns:
        The backend class named by the platform, resolved by qualified name.

    Raises:
        ValueError: if ``VLLM_ATTENTION_BACKEND`` names an unknown backend,
            or the platform returns no backend class.
    """
    # Accessing envs.* (or other mutable global state) behind the @cache
    # decorator of `_cached_get_attn_backend` can cause the wrong value to be
    # returned from the cache if that state changes between calls. To avoid
    # this, read every such input here and pass it explicitly so it becomes
    # part of the cache key.
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        use_v1=envs.VLLM_USE_V1,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        forced_backend=get_global_forced_attn_backend(),
        backend_env_var=envs.VLLM_ATTENTION_BACKEND,
    )


# Sentinel default for `_cached_get_attn_backend`: "argument not passed, read
# the current value inside the function". Keeps direct callers of the private
# function working exactly as before.
_NOT_PASSED: object = object()


@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_v1: bool = False,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    forced_backend: Optional[_Backend] = _NOT_PASSED,  # type: ignore[assignment]
    backend_env_var: Optional[str] = _NOT_PASSED,  # type: ignore[assignment]
) -> type[AttentionBackend]:
    """Memoized implementation of ``get_attn_backend``.

    ``forced_backend`` is the globally forced backend and ``backend_env_var``
    the value of ``VLLM_ATTENTION_BACKEND``; ``get_attn_backend`` passes both
    so they are part of the cache key. When omitted they are read here, in
    which case cached entries do not notice later changes to either.
    """
    if forced_backend is _NOT_PASSED:
        forced_backend = get_global_forced_attn_backend()
    if backend_env_var is _NOT_PASSED:
        backend_env_var = envs.VLLM_ATTENTION_BACKEND

    # A previously forced backend choice takes precedence.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    selected_backend: Optional[_Backend] = forced_backend
    if selected_backend is None and backend_env_var is not None:
        # Otherwise, honor the environment variable if specified.
        if backend_env_var.endswith("_VLLM_V1"):
            logger.warning(
                "The suffix '_VLLM_V1' in the environment variable "
                "%s is no longer necessary as V0 backends have been "
                "deprecated. Please remove this suffix from your "
                "environment variable setting.",
                STR_BACKEND_ENV_VAR,
            )
            backend_env_var = backend_env_var.removesuffix("_VLLM_V1")
        selected_backend = backend_name_to_enum(backend_env_var)
        if selected_backend is None:
            raise ValueError(
                f"Invalid attention backend: '{backend_env_var}'. "
                f"Valid backends are: {list(_Backend.__members__.keys())}"
            )

    # get device-specific attn_backend
    attention_cls = current_platform.get_attn_backend_cls(
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink,
        use_sparse,
    )
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )
    return resolve_obj_by_qualname(attention_cls)
211
212
213
214


@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: _Backend,
) -> Generator[None, None, None]:
    """
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator that yields once with the override active; the previous
      override is restored when it resumes, even if the body raised.
    """
    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override. `_cached_get_attn_backend`
    # may already hold an entry for the same arguments computed under the
    # previous override, so drop memoized selections before yielding.
    global_force_attn_backend(attn_backend)
    _cached_get_attn_backend.cache_clear()

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any, and drop
        # selections memoized under the temporary override.
        global_force_attn_backend(original_value)
        _cached_get_attn_backend.cache_clear()