selector.py 6.34 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from collections.abc import Generator
6
from contextlib import contextmanager
7
from functools import cache
8
from typing import cast, get_args
9
10
11

import torch

12
import vllm.envs as envs
13
from vllm.attention.backends.abstract import AttentionBackend
14
15
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.cache import CacheDType
16
from vllm.logger import init_logger
17
18
from vllm.utils import STR_BACKEND_ENV_VAR
from vllm.utils.import_utils import resolve_obj_by_qualname
19
20
21
22

# Module-level logger, obtained from vLLM's `init_logger` helper.
logger = init_logger(__name__)


def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
    """
    Get the backend override specified by the vLLM attention
    backend environment variable, if one is specified.

    The deprecated ``_VLLM_V1`` suffix is accepted and stripped, matching
    how the same environment variable is interpreted during backend
    selection.

    Returns:

    * AttentionBackendEnum value if an override is specified
    * None otherwise

    Raises:

    * KeyError if the variable names a backend that is not a member of
      AttentionBackendEnum
    """
    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
    if backend_name is None:
        return None
    # Backend selection tolerates the legacy "_VLLM_V1" suffix (with a
    # deprecation warning); accept it here too so both code paths agree on
    # what a valid value is instead of this one failing with a KeyError.
    backend_name = backend_name.removesuffix("_VLLM_V1")
    return AttentionBackendEnum[backend_name]


# Module-level override for attention backend selection.
#
# While this is None (the default), the backend is picked automatically
# from the system and workload configuration. Setting it to a backend
# forces that backend to be used instead.
#
# NOTE: this override takes precedence over the VLLM_ATTENTION_BACKEND
# environment variable.
forced_attn_backend: AttentionBackendEnum | None = None


def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
    """
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    The memoized backend-selection cache is cleared, so the new setting
    also applies to argument combinations that were already resolved.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend
    # `_cached_get_attn_backend` is memoized on its arguments only, not on
    # this global. Without clearing it, previously cached selections would
    # keep being returned and the new override would be silently ignored.
    _cached_get_attn_backend.cache_clear()


def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
    """Return the attention backend currently forced globally.

    None means no global override is set. Selection then falls back to
    the environment variable or to automatic choice.
    """
    override = forced_attn_backend
    return override


def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    When given, ``kv_cache_dtype`` is checked against the values returned
    by ``get_args(CacheDType)``. The actual selection is then delegated to
    the memoized ``_cached_get_attn_backend``.
    """
    if kv_cache_dtype is not None:
        allowed_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in allowed_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {allowed_dtypes}"
        )

    # `cast` only narrows the static type after the check above. Keyword
    # order below is kept stable because it is part of the cache key.
    checked_cache_dtype = cast(CacheDType | None, kv_cache_dtype)
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=checked_cache_dtype,
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
    )


def _resolve_backend_override() -> AttentionBackendEnum | None:
    """Return the explicitly requested attention backend, if any.

    A backend forced through `global_force_attn_backend` takes precedence
    over the VLLM_ATTENTION_BACKEND environment variable (read via `envs`).
    Returns None when neither source requests a backend.

    Raises:
        ValueError: if the environment variable names an unknown backend.
    """
    forced = get_global_forced_attn_backend()
    if forced is not None:
        return forced

    env_choice: str | None = envs.VLLM_ATTENTION_BACKEND
    if env_choice is None:
        return None

    if env_choice.endswith("_VLLM_V1"):
        logger.warning(
            "The suffix '_VLLM_V1' in the environment variable "
            "%s is no longer necessary as V0 backends have been "
            "deprecated. Please remove this suffix from your "
            "environment variable setting.",
            STR_BACKEND_ENV_VAR,
        )
        env_choice = env_choice.removesuffix("_VLLM_V1")

    try:
        return AttentionBackendEnum[env_choice]
    except KeyError as err:
        raise ValueError(
            f"Invalid attention backend: '{env_choice}'. Valid "
            f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
        ) from err


@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Resolve, import and return the attention backend class.

    The result is memoized on the arguments only. The global override and
    the environment variable are read on a cache miss, so changes to them
    are not seen until the cache is cleared.

    Raises:
        ValueError: if the environment variable names an unknown backend,
            or if the platform reports no backend class.
    """
    requested = _resolve_backend_override()

    # The platform is imported lazily here rather than at module top
    # (presumably to avoid import-time cost or cycles; TODO confirm).
    from vllm.platforms import current_platform

    attention_cls = current_platform.get_attn_backend_cls(
        requested,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        True,  # positional flag passed unchanged; presumably use_v1 - confirm
        use_mla,
        has_sink,
        use_sparse,
    )
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )

    backend = resolve_obj_by_qualname(attention_cls)

    # Some backends require a specific KV cache layout; apply it if so.
    layout = backend.get_required_kv_cache_layout()
    if layout is None:
        return backend

    from vllm.v1.attention.backends.utils import set_kv_cache_layout

    set_kv_cache_layout(layout)
    logger.info(
        "Using %s KV cache layout for %s backend.",
        layout,
        backend.get_name(),
    )
    return backend


@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
    """
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    The memoized backend-selection cache is cleared both on entry and on
    exit. Selections made before entering therefore do not mask the
    override, and selections made under the override do not outlive it.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    """
    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)
    # `_cached_get_attn_backend` is keyed on its arguments only. Without this
    # clear, calls inside the context with previously seen arguments would
    # return the old cached backend and ignore the override.
    _cached_get_attn_backend.cache_clear()

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any
        global_force_attn_backend(original_value)
        # Drop selections made under the override so they don't leak out.
        _cached_get_attn_backend.cache_clear()