selector.py 10.6 KB
Newer Older
1
2
import os
from contextlib import contextmanager
3
from functools import lru_cache
4
from typing import Generator, Optional, Type
5
6
7

import torch

8
import vllm.envs as envs
9
10
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
11
from vllm.platforms import _Backend, current_platform
12
from vllm.utils import STR_BACKEND_ENV_VAR
13
14
15
16

# Module logger; used below to report which attention backend is chosen
# and why FlashAttention-2 fallbacks happen.
logger = init_logger(__name__)


17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def backend_name_to_enum(backend_name: str) -> _Backend:
    """Map an attention backend name to its ``_Backend`` member.

    The name must exactly match a ``_Backend`` member name (matching is
    case-sensitive).

    Raises:
        ValueError: if ``backend_name`` is not a ``_Backend`` member name;
            the message lists every valid name.
    """
    assert backend_name is not None

    known_names = _Backend.__members__
    if backend_name in known_names:
        return _Backend[backend_name]

    raise ValueError(f"Invalid attention backend '{backend_name}'. "
                     f"Available backends: {', '.join(known_names)} "
                     "(case-sensitive).")


def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Read the attention backend override from the environment.

    Looks up the variable named by STR_BACKEND_ENV_VAR in os.environ.

    Returns:

    * the matching _Backend member when the variable is set
      (an invalid name raises ValueError via backend_name_to_enum)
    * None when the variable is not set
    '''
    env_value = os.environ.get(STR_BACKEND_ENV_VAR)
    if env_value is None:
        return None
    return backend_name_to_enum(env_value)


# Global override for attention backend selection.
#
# When not None, which_attn_to_use() (via get_global_forced_attn_backend())
# takes this value as its backend choice instead of consulting the
# VLLM_ATTENTION_BACKEND environment variable or the built-in default.
# The platform-default and FlashAttention-2 compatibility checks in
# which_attn_to_use() still apply afterwards. None (the initial value)
# means no override: automatic selection based on system & workload
# configuration.
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
#
# Change it through global_force_attn_backend() or
# global_force_attn_backend_context_manager() rather than assigning it
# directly.
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    Backend lookups are memoized by `_cached_get_attn_backend`
    (an `lru_cache`), which only consults this override on a cache
    miss. The cache is cleared here so that the new setting (or the
    revert to automatic selection) takes effect on the next
    `get_attn_backend` call instead of being masked by a stale entry.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    '''
    global forced_attn_backend
    forced_attn_backend = attn_backend
    # `_cached_get_attn_backend` is defined later in this module; the name
    # is resolved when this function runs, not at import time.
    _cached_get_attn_backend.cache_clear()


def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Return the backend stored by global_force_attn_backend(),
    or None when no override is active (auto-selection).
    '''
    current_override = forced_attn_backend
    return current_override


77
78
79
80
81
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    This is the uncached entry point; the actual work is done by the
    memoized ``_cached_get_attn_backend``.

    Returns:
        The ``AttentionBackend`` subclass chosen for these parameters.
    """
    # envs.* values must not be read from inside the @lru_cache'd helper:
    # if such a value changed between calls, a stale cache entry would be
    # returned. Read VLLM_USE_V1 here so it becomes part of the cache key.
    v1_enabled = envs.VLLM_USE_V1
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        is_attention_free=is_attention_free,
        is_blocksparse=is_blocksparse,
        use_v1=v1_enabled,
    )


@lru_cache(maxsize=None)
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_v1: bool = False,
) -> Type[AttentionBackend]:
    """Resolve the attention backend and lazily import its class.

    Memoized with ``lru_cache`` on the arguments. State read indirectly
    through ``which_attn_to_use`` (the global forced backend and the
    VLLM_ATTENTION_BACKEND setting) is therefore only consulted on a
    cache miss.

    Args:
        head_size: attention head size.
        dtype: model dtype.
        kv_cache_dtype: KV-cache dtype string, or None.
        block_size: KV-cache block size.
        is_attention_free: True when the model has no attention layers.
        is_blocksparse: select the block-sparse FlashAttention backend.
        use_v1: whether the V1 engine is enabled.

    Returns:
        The ``AttentionBackend`` subclass to use.

    Raises:
        RuntimeError: if TORCH_SDPA is selected on a non-CPU platform, or
            IPEX on a non-XPU platform.
        ValueError: if the selected backend is not recognized.
    """
    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from vllm.attention.backends.blocksparse_attn import (
            BlocksparseFlashAttentionBackend)
        return BlocksparseFlashAttentionBackend

    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
                                is_attention_free, use_v1)

    # Imports are local so that only the selected backend's dependencies
    # are loaded.
    if backend == _Backend.FLASH_ATTN:
        logger.info("Using Flash Attention backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    elif backend == _Backend.FLASH_ATTN_VLLM_V1:
        logger.info("Using Flash Attention backend on V1 engine.")
        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend as FlashAttentionBackendV1)
        return FlashAttentionBackendV1
    elif backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, and `assert cond, RuntimeError(...)` raised an
        # AssertionError rather than the intended RuntimeError.
        if not current_platform.is_cpu():
            raise RuntimeError(
                "Torch SDPA backend is only used for the CPU device.")
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    elif backend == _Backend.OPENVINO:
        logger.info("Using OpenVINO Attention backend.")
        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
        return OpenVINOAttentionBackend
    elif backend == _Backend.IPEX:
        if not current_platform.is_xpu():
            raise RuntimeError(
                "IPEX attention backend is only used for the XPU device.")
        logger.info("Using IPEX attention backend.")
        from vllm.attention.backends.ipex_attn import IpexAttnBackend
        return IpexAttnBackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    elif backend == _Backend.HPU_ATTN:
        logger.info("Using HPUAttention backend.")
        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
        return HPUAttentionBackend
    elif backend == _Backend.PALLAS:
        logger.info("Using Pallas backend.")
        from vllm.attention.backends.pallas import PallasAttentionBackend
        return PallasAttentionBackend
    elif backend == _Backend.NO_ATTENTION:
        from vllm.attention.backends.placeholder_attn import (
            PlaceholderAttentionBackend)
        return PlaceholderAttentionBackend
    else:
        raise ValueError(f"Invalid attention backend: {backend}.")
172
173


Joe Runde's avatar
Joe Runde committed
174
175
176
177
178
179
def which_attn_to_use(head_size: int,
                      dtype: torch.dtype,
                      kv_cache_dtype: Optional[str],
                      block_size: int,
                      is_attention_free: bool,
                      use_v1: bool = False) -> _Backend:
    """Return which attention backend to use, as a ``_Backend`` value.

    Selection order:

    1. ``NO_ATTENTION`` when the model has no attention layers.
    2. The starting choice is the globally forced backend
       (``get_global_forced_attn_backend()``); if none is forced, the
       backend named by ``envs.VLLM_ATTENTION_BACKEND``; otherwise
       ``FLASH_ATTN``.
    3. If ``current_platform.get_default_attn_backend`` returns a backend
       for that choice, it is returned as-is.
    4. ``FLASH_ATTN_VLLM_V1`` when ``use_v1`` is set.
    5. A ``FLASH_ATTN`` choice falls back to ``XFORMERS`` when
       FlashAttention-2 cannot be used. Reasons are the device capability,
       dtype, FP8 KV cache, block size, head size, or the
       ``vllm.vllm_flash_attn`` package being unavailable. Any other
       choice is returned unchanged.

    Raises:
        ValueError: if VLLM_ATTENTION_BACKEND names an unknown backend.
    """
    # If there are no attention layers (e.g. we are running Mamba),
    # use the placeholder NO_ATTENTION
    if is_attention_free:
        return _Backend.NO_ATTENTION

    # Default case.
    selected_backend = _Backend.FLASH_ATTN

    # Check whether a particular choice of backend was
    # previously forced.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    backend_by_global_setting: Optional[_Backend] = (
        get_global_forced_attn_backend())
    if backend_by_global_setting is not None:
        selected_backend = backend_by_global_setting
    else:
        # Check the environment variable and override if specified
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)

    # A device-specific default short-circuits the remaining checks.
    default_backend = current_platform.get_default_attn_backend(
        selected_backend)
    if default_backend is not None:
        return default_backend

    if use_v1:
        return _Backend.FLASH_ATTN_VLLM_V1

    # FlashAttn in NVIDIA GPUs: rule out configurations FA2 can't serve.
    if selected_backend == _Backend.FLASH_ATTN:
        if not current_platform.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            selected_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            selected_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            logger.warning(
                "Please use FlashInfer backend with FP8 KV Cache for "
                "better performance by setting environment variable "
                "VLLM_ATTENTION_BACKEND=FLASHINFER")
            selected_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            selected_backend = _Backend.XFORMERS

    # FlashAttn is valid for the model, checking if the package is installed.
    if selected_backend == _Backend.FLASH_ATTN:
        # Only the imports go inside `try`, so that the ImportError
        # fallback cannot mask unrelated errors from the head-size check.
        try:
            import vllm.vllm_flash_attn  # noqa: F401
            from vllm.attention.backends.flash_attn import (
                FlashAttentionBackend)
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "vllm.vllm_flash_attn package is not found. "
                "Make sure that vllm_flash_attn was built and installed "
                "(on by default).")
            selected_backend = _Backend.XFORMERS
        else:
            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size %d.",
                    head_size)
                selected_backend = _Backend.XFORMERS

    return selected_backend
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293


@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Temporarily force a vLLM attention backend for the duration of a
    `with` block.

    Whatever override was active before entering (possibly None) is
    restored on exit, including when the block raises.

    Arguments:

    * attn_backend: attention backend to force inside the block

    Returns:

    * Generator (consumed by contextlib.contextmanager)
    '''
    # Remember the override in effect before entering the block.
    previous_backend = get_global_forced_attn_backend()
    global_force_attn_backend(attn_backend)
    try:
        yield
    finally:
        # Always restore the prior override, even on exceptions.
        global_force_attn_backend(previous_backend)