selector.py 12.1 KB
Newer Older
1
import enum
2
3
import os
from contextlib import contextmanager
4
from functools import lru_cache
5
from typing import Generator, Optional, Type
6
7
8

import torch

9
import vllm.envs as envs
10
11
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
12
from vllm.platforms import current_platform
13
from vllm.utils import STR_BACKEND_ENV_VAR
14
15
16
17

# Module-level logger; the selection functions in this file use it to report
# which attention backend was chosen and why a requested one was rejected.
logger = init_logger(__name__)


18
19
class _Backend(enum.Enum):
    """Identifiers for the attention backends this module can select.

    Members are looked up by name (see ``backend_name_to_enum``), so the
    member names are part of the user-facing configuration surface.
    The numeric values come from ``enum.auto()`` and therefore depend on
    declaration order; do not reorder existing members.
    """

    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
    IPEX = enum.auto()
    # Placeholder used when the model has no attention layers.
    NO_ATTENTION = enum.auto()
30
31


32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def backend_name_to_enum(backend_name: str) -> _Backend:
    '''
    Translate a backend name into its _Backend enum member.

    The lookup is case-sensitive and matches the member names exactly.

    Arguments:

    * backend_name: name of a _Backend member; must not be None

    Returns:

    * The _Backend member with that name

    Raises:

    * ValueError if no _Backend member has that name
    '''
    assert backend_name is not None

    known_backends = _Backend.__members__
    member = known_backends.get(backend_name)
    if member is None:
        # Enum members are never None, so None here means "no such name".
        raise ValueError(f"Invalid attention backend '{backend_name}'. "
                         f"Available backends: {', '.join(known_backends)} "
                         "(case-sensitive).")

    return member


def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Read the attention backend override from the process environment.

    The variable consulted is the one named by STR_BACKEND_ENV_VAR.

    Returns:

    * The _Backend member named by the variable, if the variable is set
    * None if the variable is not set

    Raises:

    * ValueError (from backend_name_to_enum) if the variable is set to a
      name that is not a _Backend member
    '''
    env_value = os.environ.get(STR_BACKEND_ENV_VAR)
    if env_value is None:
        return None
    return backend_name_to_enum(env_value)


# Module-level override for the attention backend.
#
# When this is None (the default), the backend is auto-selected based on
# system & workload configuration. When it holds a _Backend member, that
# choice is used as the requested backend instead.
#
# It is written by global_force_attn_backend() and read through
# get_global_forced_attn_backend().
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Set (or clear) the module-wide forced attention backend.

    The value is stored in the module global `forced_attn_backend`.
    Passing `None` clears the override, which re-enables automatic
    backend selection.

    Arguments:

    * attn_backend: backend to force, or None to revert to auto-selection
    '''
    global forced_attn_backend
    forced_attn_backend = attn_backend


def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Return the module-wide forced attention backend.

    Returns:

    * The _Backend currently stored in `forced_attn_backend`
    * None if no backend is forced (auto-selection is in effect)
    '''
    return forced_attn_backend


92
93
94
95
96
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    This is a thin, uncached front for ``_cached_get_attn_backend``: it
    forwards every argument unchanged and adds the current value of
    ``envs.VLLM_USE_V1``.

    Args:
        head_size: Attention head size, forwarded to backend selection.
        dtype: Model dtype, forwarded to backend selection.
        kv_cache_dtype: KV cache dtype name, or None.
        block_size: KV cache block size, forwarded to backend selection.
        is_attention_free: Whether the model has no attention layers.
        is_blocksparse: Whether to use the block-sparse backend.

    Returns:
        The selected attention backend class.
    """
    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
    # value to be returned from the cache if the value changes between calls.
    # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
    # private function.
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        is_attention_free=is_attention_free,
        is_blocksparse=is_blocksparse,
        use_v1=envs.VLLM_USE_V1,
    )


@lru_cache(maxsize=None)
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_v1: bool = False,
) -> Type[AttentionBackend]:
    """Resolve the attention backend class for a configuration (memoized).

    The result is cached per argument tuple, so anything that can change
    between calls must be passed in as an argument (this is why ``use_v1``
    is a parameter rather than being read from ``envs`` here).

    Each backend module is imported only inside the branch that returns
    it, so unused backends are never imported.

    Args:
        head_size: Attention head size, forwarded to ``which_attn_to_use``.
        dtype: Model dtype, forwarded to ``which_attn_to_use``.
        kv_cache_dtype: KV cache dtype name, or None.
        block_size: KV cache block size.
        is_attention_free: Whether the model has no attention layers.
        is_blocksparse: If True, the block-sparse backend is returned
            without consulting ``which_attn_to_use``.
        use_v1: Forwarded to ``which_attn_to_use``.

    Returns:
        The attention backend class to use.

    Raises:
        RuntimeError: If TORCH_SDPA is selected on a non-CPU platform or
            IPEX is selected on a non-XPU platform.
        ValueError: If the selected backend has no matching branch.
    """
    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from vllm.attention.backends.blocksparse_attn import (
            BlocksparseFlashAttentionBackend)
        return BlocksparseFlashAttentionBackend

    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
                                is_attention_free, use_v1)

    if backend == _Backend.FLASH_ATTN:
        logger.info("Using Flash Attention backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend

    if backend == _Backend.FLASH_ATTN_VLLM_V1:
        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend as FlashAttentionBackendV1)
        return FlashAttentionBackendV1

    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend

    if backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend

    if backend == _Backend.TORCH_SDPA:
        # Raise explicitly instead of `assert`: asserts are removed under
        # `python -O`, which would let a mis-selected backend through.
        if not current_platform.is_cpu():
            raise RuntimeError(
                "Torch SDPA backend is only used for the CPU device.")
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend

    if backend == _Backend.OPENVINO:
        logger.info("Using OpenVINO Attention backend.")
        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
        return OpenVINOAttentionBackend

    if backend == _Backend.IPEX:
        # Same reasoning as for TORCH_SDPA: validate with a real exception.
        if not current_platform.is_xpu():
            raise RuntimeError(
                "IPEX attention backend is only used for the XPU device.")
        logger.info("Using IPEX attention backend.")
        from vllm.attention.backends.ipex_attn import IpexAttnBackend
        return IpexAttnBackend

    if backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend

    if backend == _Backend.HPU_ATTN:
        logger.info("Using HPUAttention backend.")
        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
        return HPUAttentionBackend

    if backend == _Backend.PALLAS:
        logger.info("Using Pallas backend.")
        from vllm.attention.backends.pallas import PallasAttentionBackend
        return PallasAttentionBackend

    if backend == _Backend.NO_ATTENTION:
        from vllm.attention.backends.placeholder_attn import (
            PlaceholderAttentionBackend)
        return PlaceholderAttentionBackend

    raise ValueError("Invalid attention backend.")
187
188


Joe Runde's avatar
Joe Runde committed
189
190
191
192
193
194
def which_attn_to_use(head_size: int,
                      dtype: torch.dtype,
                      kv_cache_dtype: Optional[str],
                      block_size: int,
                      is_attention_free: bool,
                      use_v1: bool = False) -> _Backend:
    """Returns which flash attention backend to use.

    Selection proceeds in this order:

    1. If ``is_attention_free`` is set, NO_ATTENTION is returned at once.
    2. The requested backend is the globally forced backend if one is
       set, else the one named by ``envs.VLLM_ATTENTION_BACKEND`` if set,
       else FLASH_ATTN.
    3. On CPU, OpenVINO, XPU, TPU, ROCm and HPU platforms the platform's
       own backend is returned regardless of the request (a message is
       logged on some platforms when the request differs).
    4. If ``use_v1`` is set, FLASH_ATTN_VLLM_V1 is returned regardless of
       the request.
    5. Otherwise, a FLASH_ATTN request is downgraded to XFORMERS when the
       device, dtype, KV cache dtype, block size or head size is not
       supported, or when ``vllm.vllm_flash_attn`` cannot be imported.
       Any other request is returned unchanged.

    Args:
        head_size: Attention head size; checked against the head sizes
            FlashAttentionBackend reports as supported.
        dtype: Model dtype; FlashAttention needs float16 or bfloat16.
        kv_cache_dtype: KV cache dtype name, or None. Names starting
            with "fp8" rule out FlashAttention.
        block_size: KV cache block size; FlashAttention needs a
            multiple of 16.
        is_attention_free: Whether the model has no attention layers.
        use_v1: Whether to select the V1 flash attention backend.

    Returns:
        The selected _Backend member.

    Raises:
        ValueError: If ``envs.VLLM_ATTENTION_BACKEND`` names an unknown
            backend (raised by ``backend_name_to_enum``).
    """
    # Default case.
    selected_backend = _Backend.FLASH_ATTN

    # If there are no attention layers (e.g. we are running Mamba),
    # use the placeholder NO_ATTENTION
    if is_attention_free:
        return _Backend.NO_ATTENTION

    # Check whether a particular choice of backend was
    # previously forced.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    backend_by_global_setting: Optional[_Backend] = (
        get_global_forced_attn_backend())
    if backend_by_global_setting is not None:
        selected_backend = backend_by_global_setting
    else:
        # Check the environment variable and override if specified
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)

    # Platforms with a single usable backend: the request is only used to
    # decide whether to log that it is being ignored.
    if current_platform.is_cpu():
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        return _Backend.TORCH_SDPA

    if current_platform.is_openvino():
        if selected_backend != _Backend.OPENVINO:
            logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
        return _Backend.OPENVINO

    if current_platform.is_xpu():
        if selected_backend != _Backend.IPEX:
            logger.info("Cannot use %s backend on XPU.", selected_backend)
        return _Backend.IPEX

    if current_platform.is_tpu():
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
        return _Backend.PALLAS

    if current_platform.is_rocm():
        # AMD GPUs.
        # A FLASH_ATTN request (including the default) maps to ROCM_FLASH.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:
            if not current_platform.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        # ROCM_FLASH is returned in every case; the branches above only log.
        return _Backend.ROCM_FLASH

    if current_platform.is_hpu():
        return _Backend.HPU_ATTN

    if use_v1:
        return _Backend.FLASH_ATTN_VLLM_V1

    # FlashAttn in NVIDIA GPUs.
    if selected_backend == _Backend.FLASH_ATTN:
        if not current_platform.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            selected_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            selected_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            logger.warning(
                "Please use FlashInfer backend with FP8 KV Cache for "
                "better performance by setting environment variable  "
                "VLLM_ATTENTION_BACKEND=FLASHINFER")
            selected_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            selected_backend = _Backend.XFORMERS

    # FlashAttn is valid for the model, checking if the package is installed.
    if selected_backend == _Backend.FLASH_ATTN:
        try:
            import vllm.vllm_flash_attn  # noqa: F401
            from vllm.attention.backends.flash_attn import (  # noqa: F401
                FlashAttentionBackend)

            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size %d.",
                    head_size)
                selected_backend = _Backend.XFORMERS
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "vllm.vllm_flash_attn package is not found. "
                "Make sure that vllm_flash_attn was built and installed "
                "(on by default).")
            selected_backend = _Backend.XFORMERS

    return selected_backend
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337


@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Temporarily force a vLLM attention backend for the duration of a
    `with` block.

    On entry the current global override is remembered and replaced by
    `attn_backend`. On exit the remembered override is put back, whether
    the block finished normally or raised.

    Arguments:

    * attn_backend: attention backend to force inside the block

    Returns:

    * Generator (consumed by @contextmanager; yields nothing useful)
    '''

    # Remember whatever override (possibly None) was active on entry
    previous_backend = get_global_forced_attn_backend()

    # Install the requested override
    global_force_attn_backend(attn_backend)

    try:
        # Hand control to the body of the `with` statement
        yield
    finally:
        # Put the remembered override back, even if the body raised
        global_force_attn_backend(previous_backend)