# selector.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import os
from contextlib import contextmanager
6
from dataclasses import dataclass
7
from functools import cache
8
from typing import Generator, Optional, Union
9
10
11

import torch

12
import vllm.envs as envs
13
from vllm.attention.backends.abstract import AttentionBackend
14
from vllm.attention.backends.registry import _Backend
15
from vllm.logger import init_logger
16
from vllm.platforms import current_platform
17
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
18
19
20
21

# Module-level logger; used in _cached_get_attn_backend to warn about the
# deprecated "_VLLM_V1" suffix in the attention-backend environment variable.
logger = init_logger(__name__)


22
23
24
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """
    Look up the ``_Backend`` member whose name is ``backend_name``.

    Returns:
    * the matching ``_Backend`` member when ``backend_name`` is a valid
      in-tree backend name
    * None otherwise: either the name is not a valid in-tree backend, or it
      refers to a backend of an out-of-tree platform.
    """
    assert backend_name is not None
    # ``__members__`` maps every member name (aliases included) to its
    # member, exactly like ``_Backend[name]``; ``.get`` yields None on a miss.
    return _Backend.__members__.get(backend_name)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56


def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Read the attention-backend override from the vLLM attention-backend
    environment variable (``STR_BACKEND_ENV_VAR``).

    Returns:

    * the corresponding _Backend member when the variable is set to a
      known backend name
    * None when the variable is unset or names no in-tree backend
    '''
    raw_name = os.environ.get(STR_BACKEND_ENV_VAR)
    if raw_name is None:
        return None
    return backend_name_to_enum(raw_name)


# Process-wide override for attention-backend selection. While this is None
# (the default) the backend is auto-selected from the system and workload
# configuration; otherwise the stored backend is requested instead.
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    The memoized lookup `_cached_get_attn_backend` reads this override
    from inside its cache, but the override is not part of the cache key.
    The cache is therefore cleared here, so that backends resolved under the
    previous setting are not returned afterwards.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    '''
    global forced_attn_backend
    forced_attn_backend = attn_backend
    # Drop memoized results computed under the previous override; otherwise
    # get_attn_backend() would keep returning the stale backend for any
    # argument combination it had already seen.
    _cached_get_attn_backend.cache_clear()


def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Return the backend currently forced via `global_force_attn_backend`.

    A return value of None means no backend is forced and automatic
    selection is in effect.
    '''
    return forced_attn_backend


84
85
86
87
88
89
90
91
92
93
94
@dataclass(frozen=True)
class _IsSupported:
    can_import: bool
    head_size: bool
    dtype: bool

    def __bool__(self) -> bool:
        return self.can_import and self.head_size and self.dtype


def _head_size_supported(backend: type, head_size: int) -> bool:
    """Check ``head_size`` using whichever validation hook ``backend`` has.

    ``get_supported_head_sizes`` is preferred; ``validate_head_size`` is
    consulted only when the former is missing (or falsy). Any exception
    raised by ``validate_head_size`` counts as "unsupported".

    Raises:
        NotImplementedError: if ``backend`` exposes neither hook.
    """
    # TODO: Update the interface once V0 is removed
    sizes_getter = getattr(backend, "get_supported_head_sizes", None)
    if sizes_getter:
        return head_size in sizes_getter()

    validator = getattr(backend, "validate_head_size", None)
    if not validator:
        raise NotImplementedError(f"{backend.__name__} does not support "
                                  "head size validation")
    try:
        validator(head_size)
    except Exception:
        return False
    else:
        return True


def _dtype_supported(backend: type, dtype: torch.dtype) -> bool:
    """Check ``dtype`` against ``backend.get_supported_dtypes()``.

    Raises:
        NotImplementedError: if ``backend`` has no ``get_supported_dtypes``.
    """
    dtypes_getter = getattr(backend, "get_supported_dtypes", None)
    if not dtypes_getter:
        raise NotImplementedError(f"{backend.__name__} does not support "
                                  "dtype validation")
    return dtype in dtypes_getter()


def is_attn_backend_supported(
    attn_backend: Union[str, type[AttentionBackend]],
    head_size: int,
    dtype: torch.dtype,
    *,
    allow_import_error: bool = True,
) -> _IsSupported:
    """Report whether an attention backend handles a head size and dtype.

    Args:
        attn_backend: the backend class, or its fully qualified name (which
            is resolved with ``resolve_obj_by_qualname``).
        head_size: attention head size to check.
        dtype: tensor dtype to check.
        allow_import_error: when True, a failure to import a backend given
            by name is reported as an all-False result; when False the
            ``ImportError`` propagates.

    Returns:
        An ``_IsSupported`` holding the outcome of each individual check.

    Raises:
        ImportError: when resolution fails and ``allow_import_error`` is
            False.
        NotImplementedError: when the backend lacks the hooks needed to
            validate the head size or the dtype.
    """
    if isinstance(attn_backend, str):
        try:
            attn_backend = resolve_obj_by_qualname(attn_backend)
        except ImportError:
            if allow_import_error:
                return _IsSupported(can_import=False,
                                    head_size=False,
                                    dtype=False)
            raise

    assert isinstance(attn_backend, type)

    # Head size is checked before dtype, so a backend missing both hooks
    # reports the head-size error first.
    head_size_ok = _head_size_supported(attn_backend, head_size)
    dtype_ok = _dtype_supported(attn_backend, dtype)
    return _IsSupported(can_import=True, head_size=head_size_ok, dtype=dtype_ok)
139
140


141
142
143
144
145
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    Thin front end for the memoized ``_cached_get_attn_backend``; all
    arguments are forwarded unchanged, plus the current ``VLLM_USE_V1``.
    """
    # Accessing envs.* inside the @cache'd helper could return a stale
    # result if the value changes between calls. Snapshot VLLM_USE_V1 here
    # and pass it explicitly so that it becomes part of the cache key.
    use_v1_now = envs.VLLM_USE_V1
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        use_v1=use_v1_now,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
    )


167
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_v1: bool = False,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Memoized backend resolution used by `get_attn_backend`.

    The requested backend comes from the global override
    (`get_global_forced_attn_backend`) if one is set, otherwise from
    ``envs.VLLM_ATTENTION_BACKEND``; if neither is set, None is passed to
    ``current_platform.get_attn_backend_cls``. The qualified class name it
    returns is then resolved and returned.

    Raises:
        ValueError: if the environment variable names an unknown backend,
            or the platform returns no backend class name.

    NOTE(review): the global override and ``envs.VLLM_ATTENTION_BACKEND``
    are read inside this ``@cache``'d function but are not part of the
    cache key, so a later change to either is not observed for argument
    combinations already cached unless the cache is cleared.
    """
    # A forced backend overrides the VLLM_ATTENTION_BACKEND env variable.
    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()

    if selected_backend is None:
        requested_name: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if requested_name is not None:
            if requested_name.endswith("_VLLM_V1"):
                logger.warning(
                    "The suffix '_VLLM_V1' in the environment variable "
                    "%s is no longer necessary as V0 backends have been "
                    "deprecated. Please remove this suffix from your "
                    "environment variable setting.", STR_BACKEND_ENV_VAR)
                requested_name = requested_name.removesuffix("_VLLM_V1")
            selected_backend = backend_name_to_enum(requested_name)
            if selected_backend is None:
                raise ValueError(
                    f"Invalid attention backend: '{requested_name}'. "
                    f"Valid backends are: {list(_Backend.__members__.keys())}")

    # The current platform maps the request onto a qualified class name.
    backend_qualname = current_platform.get_attn_backend_cls(
        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
        use_mla, has_sink, use_sparse)
    if not backend_qualname:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}")
    return resolve_obj_by_qualname(backend_qualname)
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246


@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Temporarily force a vLLM attention backend for the duration of a
    ``with`` block.

    The override that was in effect on entry (possibly None) is restored
    on exit, even if the block raises.

    Arguments:

    * attn_backend: attention backend to force inside the block

    Returns:

    * Generator (driven by ``contextlib.contextmanager``)
    '''
    previous_backend = get_global_forced_attn_backend()
    global_force_attn_backend(attn_backend)
    try:
        yield
    finally:
        # Put back whatever override was active before we entered.
        global_force_attn_backend(previous_backend)