# "vllm/v1/sample/ops/utils.py" did not exist on "3a1e6481586ed7f079275b5d5072a6e246af691e"
# selector.py 7.41 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from collections.abc import Generator
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from functools import cache
9
10
11

import torch

12
import vllm.envs as envs
13
from vllm.attention.backends.abstract import AttentionBackend
14
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
15
from vllm.logger import init_logger
16
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
17
18
19
20

logger = init_logger(__name__)


21
def get_env_variable_attn_backend() -> _Backend | None:
    """
    Look up the attention backend override requested through the vLLM
    attention backend environment variable (STR_BACKEND_ENV_VAR).

    Returns:

    * whatever backend_name_to_enum returns for the variable's value,
      when the variable is set
    * None when the variable is not set
    """
    env_value = os.environ.get(STR_BACKEND_ENV_VAR)
    if env_value is None:
        return None
    return backend_name_to_enum(env_value)
33
34
35
36
37
38
39
40


# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: _Backend | None = None


def global_force_attn_backend(attn_backend: _Backend | None) -> None:
    """
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    The memoized backend lookup (`_cached_get_attn_backend`) is cleared,
    so the new setting takes effect on the next selection instead of a
    result cached under the previous setting being returned.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend
    # Results memoized by `_cached_get_attn_backend` may have been computed
    # under the previous forced backend, so drop them.
    _cached_get_attn_backend.cache_clear()


def get_global_forced_attn_backend() -> _Backend | None:
    """
    Get the currently-forced choice of attention backend,
    or None if auto-selection is currently enabled.
    """
    return forced_attn_backend


68
69
70
71
72
73
74
75
76
77
78
@dataclass(frozen=True)
class _IsSupported:
    can_import: bool
    head_size: bool
    dtype: bool

    def __bool__(self) -> bool:
        return self.can_import and self.head_size and self.dtype


def is_attn_backend_supported(
    attn_backend: str | type[AttentionBackend],
    head_size: int,
    dtype: torch.dtype,
    *,
    allow_import_error: bool = True,
) -> _IsSupported:
    """Check whether an attention backend supports a head size and dtype.

    Args:
        attn_backend: The backend class, or a fully-qualified name that is
            resolved with ``resolve_obj_by_qualname``.
        head_size: Head size to validate against the backend.
        dtype: dtype to validate against the backend.
        allow_import_error: When True, an ``ImportError`` raised while
            resolving a string ``attn_backend`` produces an all-False result
            instead of propagating.

    Returns:
        An ``_IsSupported`` recording the import, head-size and dtype checks.

    Raises:
        ImportError: If resolving ``attn_backend`` fails and
            ``allow_import_error`` is False.
        NotImplementedError: If the backend exposes neither
            ``get_supported_head_sizes`` nor ``validate_head_size``, or does
            not expose ``get_supported_dtypes``.
    """
    if isinstance(attn_backend, str):
        try:
            backend_cls = resolve_obj_by_qualname(attn_backend)
        except ImportError:
            if allow_import_error:
                return _IsSupported(can_import=False, head_size=False, dtype=False)
            raise
    else:
        backend_cls = attn_backend

    assert isinstance(backend_cls, type)

    # TODO: Update the interface once V0 is removed
    head_sizes_fn = getattr(backend_cls, "get_supported_head_sizes", None)
    if head_sizes_fn:
        head_size_ok = head_size in head_sizes_fn()
    else:
        # Only fall back to the validator when no explicit size list exists.
        head_size_validator = getattr(backend_cls, "validate_head_size", None)
        if not head_size_validator:
            raise NotImplementedError(
                f"{backend_cls.__name__} does not support head size validation"
            )
        try:
            head_size_validator(head_size)
        except Exception:
            head_size_ok = False
        else:
            head_size_ok = True

    dtypes_fn = getattr(backend_cls, "get_supported_dtypes", None)
    if not dtypes_fn:
        raise NotImplementedError(
            f"{backend_cls.__name__} does not support dtype validation"
        )
    dtype_ok = dtype in dtypes_fn()

    return _IsSupported(can_import=True, head_size=head_size_ok, dtype=dtype_ok)
124
125


126
127
128
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    Args:
        head_size: Attention head size, forwarded to the platform selector.
        dtype: Model dtype, forwarded to the platform selector.
        kv_cache_dtype: KV-cache dtype string (or None), forwarded as-is.
        block_size: KV-cache block size, forwarded as-is.
        use_mla: Whether MLA attention is requested.
        has_sink: Whether attention sinks are requested.
        use_sparse: Whether sparse attention is requested.

    Returns:
        The backend class named by the platform's selection.

    Raises:
        ValueError: If VLLM_ATTENTION_BACKEND names an unknown backend, or
            the platform does not return a backend.
    """
    # Accessing envs.* (or any other mutable global state) behind a @cache
    # decorator can cause the wrong value to be returned from the cache if
    # the value changes between calls. To avoid this, every such input is
    # read here and passed explicitly to the private function, so it becomes
    # part of the cache key: envs.VLLM_USE_V1, the globally forced backend,
    # and the VLLM_ATTENTION_BACKEND override.
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        use_v1=envs.VLLM_USE_V1,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        backend_overrides=(
            get_global_forced_attn_backend(),
            envs.VLLM_ATTENTION_BACKEND,
        ),
    )


def _resolve_selected_backend(
    backend_by_global_setting: _Backend | None,
    backend_by_env_var: str | None,
) -> _Backend | None:
    """Return the user-requested backend, or None for auto-selection.

    A globally forced backend wins; otherwise the VLLM_ATTENTION_BACKEND
    value (if any) is mapped to a ``_Backend``.

    Raises:
        ValueError: If the environment variable names an unknown backend.
    """
    # THE FORCED SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    if backend_by_global_setting is not None:
        return backend_by_global_setting
    if backend_by_env_var is None:
        return None

    if backend_by_env_var.endswith("_VLLM_V1"):
        logger.warning(
            "The suffix '_VLLM_V1' in the environment variable "
            "%s is no longer necessary as V0 backends have been "
            "deprecated. Please remove this suffix from your "
            "environment variable setting.",
            STR_BACKEND_ENV_VAR,
        )
        backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")

    selected_backend = backend_name_to_enum(backend_by_env_var)
    if selected_backend is None:
        raise ValueError(
            f"Invalid attention backend: '{backend_by_env_var}'. "
            f"Valid backends are: {list(_Backend.__members__.keys())}"
        )
    return selected_backend


@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int,
    use_v1: bool = False,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    backend_overrides: tuple[_Backend | None, str | None] | None = None,
) -> type[AttentionBackend]:
    """Memoized backend selection; see ``get_attn_backend``.

    ``backend_overrides`` is ``(forced backend, VLLM_ATTENTION_BACKEND
    value)`` captured by the caller so that both are part of the cache key.
    When it is None (direct callers), both are read from the current global
    state. Such entries can go stale if that state later changes.
    """
    if backend_overrides is None:
        backend_overrides = (
            get_global_forced_attn_backend(),
            envs.VLLM_ATTENTION_BACKEND,
        )
    # Resolved inside the cache so the suffix warning is logged once per key.
    selected_backend = _resolve_selected_backend(*backend_overrides)

    # get device-specific attn_backend
    from vllm.platforms import current_platform

    attention_cls = current_platform.get_attn_backend_cls(
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink,
        use_sparse,
    )
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )
    return resolve_obj_by_qualname(attention_cls)
211
212
213
214


@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: _Backend,
) -> Generator[None, None, None]:
    """
    Temporarily force a vLLM attention backend override.

    On entry the global override is set to `attn_backend`. On exit, whether
    the enclosed block finishes normally or raises, the previous override
    is restored and the memoized backend lookup is cleared.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    """
    previous_backend = get_global_forced_attn_backend()
    global_force_attn_backend(attn_backend)
    try:
        yield
    finally:
        # Put back whatever override (possibly None) was active before, then
        # discard lookups memoized while the temporary override was in place.
        global_force_attn_backend(previous_backend)
        _cached_get_attn_backend.cache_clear()