# selector.py
1
import enum
2
3
import os
from contextlib import contextmanager
4
from functools import lru_cache
5
from typing import Generator, Optional, Type
6
7
8

import torch

9
import vllm.envs as envs
10
11
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
12
from vllm.platforms import current_platform
13
from vllm.utils import STR_BACKEND_ENV_VAR
14
15
16
17

logger = init_logger(__name__)


18
19
class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
20
    FLASH_ATTN_VLLM_V1 = enum.auto()
21
22
23
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
24
    OPENVINO = enum.auto()
25
    FLASHINFER = enum.auto()
26
    HPU_ATTN = enum.auto()
27
    PALLAS = enum.auto()
28
    IPEX = enum.auto()
29
    NO_ATTENTION = enum.auto()
30
31


32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def backend_name_to_enum(backend_name: str) -> _Backend:
    '''
    Map an attention backend name onto its _Backend member.

    Arguments:

    * backend_name: exact, case-sensitive member name (e.g. "XFORMERS")

    Returns:

    * the _Backend member with that name

    Raises:

    * ValueError if no _Backend member has that name
    * AssertionError if backend_name is None (when assertions are enabled)
    '''
    assert backend_name is not None

    valid_names = _Backend.__members__
    if backend_name in valid_names:
        return _Backend[backend_name]

    raise ValueError(f"Invalid attention backend '{backend_name}'. "
                     f"Available backends: {', '.join(valid_names)} "
                     "(case-sensitive).")


def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Read the attention backend override from the process environment.

    The environment variable consulted is the one named by
    STR_BACKEND_ENV_VAR.

    Returns:

    * the matching _Backend member when the variable is set
    * None when the variable is not set

    Raises:

    * ValueError (from backend_name_to_enum) if the variable holds an
      unknown backend name
    '''
    env_value = os.environ.get(STR_BACKEND_ENV_VAR)
    if env_value is None:
        return None
    return backend_name_to_enum(env_value)


# Module-level override of the attention backend choice. When this is not
# None, backend selection uses it instead of auto-selecting a backend based
# on system & workload configuration. None (the default) means automatic
# selection is in effect.
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
#
# Prefer global_force_attn_backend() / get_global_forced_attn_backend()
# (or the context-manager helper) over touching this name directly.
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Set or clear the module-level attention backend override.

    While an override is set, every attention backend selection uses it,
    taking precedence over the attention backend environment variable.
    Passing `None` removes the override and restores automatic backend
    selection.

    Arguments:

    * attn_backend: backend to force, or None to revert to auto-selection
    '''
    global forced_attn_backend
    forced_attn_backend = attn_backend


def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Report the module-level attention backend override.

    Returns:

    * the forced _Backend member, or None when no override is set and
      the backend is auto-selected
    '''
    return forced_attn_backend


92
@lru_cache(maxsize=None)
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    *,
    forced_backend: Optional[_Backend] = None,
    env_backend: Optional[str] = None,
    use_v1: bool = False,
) -> Type[AttentionBackend]:
    """Memoized worker behind get_attn_backend.

    ``forced_backend``, ``env_backend`` and ``use_v1`` are not read here.
    They exist only so that the lru_cache key changes whenever the global
    backend override, VLLM_ATTENTION_BACKEND or VLLM_USE_V1 changes;
    which_attn_to_use reads that state itself. Without them a cached
    result would silently ignore a later override.

    Raises:
        RuntimeError: TORCH_SDPA selected off-CPU, or IPEX selected off-XPU.
        ValueError: the selected backend has no known implementation.
    """
    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from vllm.attention.backends.blocksparse_attn import (
            BlocksparseFlashAttentionBackend)
        return BlocksparseFlashAttentionBackend

    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
                                is_attention_free)

    # Each branch imports its backend lazily so that only the dependencies
    # of the selected backend need to be installed.
    if backend == _Backend.FLASH_ATTN:
        logger.info("Using Flash Attention backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    if backend == _Backend.FLASH_ATTN_VLLM_V1:
        logger.info("Using Flash Attention backend on V1 engine.")
        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend as FlashAttentionBackendV1)
        return FlashAttentionBackendV1
    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend
    if backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    if backend == _Backend.TORCH_SDPA:
        # An explicit raise (not `assert`) so the check survives `python -O`
        # and raises the RuntimeError the message was written for.
        if not current_platform.is_cpu():
            raise RuntimeError(
                "Torch SDPA backend is only used for the CPU device.")
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    if backend == _Backend.OPENVINO:
        logger.info("Using OpenVINO Attention backend.")
        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
        return OpenVINOAttentionBackend
    if backend == _Backend.IPEX:
        if not current_platform.is_xpu():
            raise RuntimeError(
                "IPEX attention backend is only used for the XPU device.")
        logger.info("Using IPEX attention backend.")
        from vllm.attention.backends.ipex_attn import IpexAttnBackend
        return IpexAttnBackend
    if backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    if backend == _Backend.HPU_ATTN:
        logger.info("Using HPUAttention backend.")
        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
        return HPUAttentionBackend
    if backend == _Backend.PALLAS:
        logger.info("Using Pallas backend.")
        from vllm.attention.backends.pallas import PallasAttentionBackend
        return PallasAttentionBackend
    if backend == _Backend.NO_ATTENTION:
        from vllm.attention.backends.placeholder_attn import (
            PlaceholderAttentionBackend)
        return PlaceholderAttentionBackend
    raise ValueError("Invalid attention backend.")


def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    Results are memoized. The global forced backend, the
    VLLM_ATTENTION_BACKEND setting and VLLM_USE_V1 are read on every
    call and folded into the cache key, so changing any of them (for
    example with global_force_attn_backend_context_manager) is honored
    instead of returning a stale cached backend.

    Returns:
        The AttentionBackend subclass for the selected backend.
    """
    return _cached_get_attn_backend(
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        is_attention_free,
        is_blocksparse,
        forced_backend=get_global_forced_attn_backend(),
        env_backend=envs.VLLM_ATTENTION_BACKEND,
        use_v1=envs.VLLM_USE_V1,
    )


# get_attn_backend used to be lru_cache-decorated directly; keep the cache
# management helpers reachable from it for existing callers.
get_attn_backend.cache_clear = (  # type: ignore[attr-defined]
    _cached_get_attn_backend.cache_clear)
get_attn_backend.cache_info = (  # type: ignore[attr-defined]
    _cached_get_attn_backend.cache_info)
163
164


165
def which_attn_to_use(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
) -> _Backend:
    """Returns which attention backend to use.

    Selection proceeds in this order:

    1. Attention-free models always get ``NO_ATTENTION``.
    2. The requested backend is the global forced backend if one is set,
       else the VLLM_ATTENTION_BACKEND setting if set, else ``FLASH_ATTN``.
       An unknown name in the setting raises ValueError (from
       backend_name_to_enum).
    3. CPU, OpenVINO, XPU and TPU platforms always return TORCH_SDPA,
       OPENVINO, IPEX and PALLAS respectively, logging when the request is
       overridden. ROCm always returns ROCM_FLASH; HPU returns HPU_ATTN.
    4. With VLLM_USE_V1 enabled, ``FLASH_ATTN_VLLM_V1`` is returned.
    5. Otherwise a ``FLASH_ATTN`` request falls back to ``XFORMERS`` when
       the device lacks capability 80 (presumably pre-Ampere, per the
       Volta/Turing log message), ``dtype`` is not fp16/bf16, the KV cache
       dtype starts with "fp8", ``block_size`` is not a multiple of 16,
       ``head_size`` is not supported by FlashAttentionBackend, or
       ``vllm.vllm_flash_attn`` cannot be imported. Any other requested
       backend is returned unchanged.
    """
    # Default case.
    selected_backend = _Backend.FLASH_ATTN

    # If there are no attention layers (e.g. we are running Mamba),
    # use the placeholder NO_ATTENTION
    if is_attention_free:
        return _Backend.NO_ATTENTION

    # Check whether a particular choice of backend was
    # previously forced.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    backend_by_global_setting: Optional[_Backend] = (
        get_global_forced_attn_backend())
    if backend_by_global_setting is not None:
        selected_backend = backend_by_global_setting
    else:
        # Check the environment variable and override if specified
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)

    # Platforms with exactly one supported backend ignore the request.
    if current_platform.is_cpu():
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        return _Backend.TORCH_SDPA

    if current_platform.is_openvino():
        if selected_backend != _Backend.OPENVINO:
            logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
        return _Backend.OPENVINO

    if current_platform.is_xpu():
        if selected_backend != _Backend.IPEX:
            logger.info("Cannot use %s backend on XPU.", selected_backend)
        return _Backend.IPEX

    if current_platform.is_tpu():
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
        return _Backend.PALLAS

    if current_platform.is_rocm():
        # AMD GPUs.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:
            if not current_platform.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        return _Backend.ROCM_FLASH

    if current_platform.is_hpu():
        return _Backend.HPU_ATTN

    if envs.VLLM_USE_V1:
        return _Backend.FLASH_ATTN_VLLM_V1

    # FlashAttn in NVIDIA GPUs.
    if selected_backend == _Backend.FLASH_ATTN:
        if not current_platform.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            selected_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            selected_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            logger.warning(
                "Please use FlashInfer backend with FP8 KV Cache for "
                "better performance by setting environment variable "
                "VLLM_ATTENTION_BACKEND=FLASHINFER")
            selected_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            selected_backend = _Backend.XFORMERS

    # FlashAttn is valid for the model, checking if the package is installed.
    if selected_backend == _Backend.FLASH_ATTN:
        try:
            import vllm.vllm_flash_attn  # noqa: F401
            from vllm.attention.backends.flash_attn import (  # noqa: F401
                FlashAttentionBackend)

            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size %d.",
                    head_size)
                selected_backend = _Backend.XFORMERS
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "vllm.vllm_flash_attn package is not found. "
                "Make sure that vllm_flash_attn was built and installed "
                "(on by default).")
            selected_backend = _Backend.XFORMERS

    return selected_backend
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314


@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Force a vLLM attention backend for the duration of a `with` block.

    The override in effect on entry (possibly None) is remembered and
    reinstated on exit, whether the block finishes normally or raises.

    Arguments:

    * attn_backend: attention backend to force inside the block

    Returns:

    * Generator (driven by contextlib.contextmanager)
    '''
    saved_override = get_global_forced_attn_backend()
    global_force_attn_backend(attn_backend)
    try:
        # Run the body of the `with` statement under the new override.
        yield
    finally:
        # Put back whatever override existed before entering the block.
        global_force_attn_backend(saved_override)