# selector.py
1
import enum
2
from functools import lru_cache
3
from typing import Optional, Type
4
5
6

import torch

7
import vllm.envs as envs
8
9
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
10
from vllm.utils import is_cpu, is_hip
11
12
13
14

# Module-level logger obtained from vLLM's logging factory (vllm.logger).
logger = init_logger(__name__)


15
16
17
18
19
class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
20
    FLASHINFER = enum.auto()
21
22


23
@lru_cache(maxsize=None)
def get_attn_backend(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
    """Select which attention backend to use and lazily import it.

    Results are memoized with ``lru_cache``, so all arguments must be
    hashable. Repeated calls with identical arguments return the cached
    class without re-running selection (and without re-logging).

    Args:
        num_heads: Forwarded to ``which_attn_to_use``.
        head_size: Forwarded to ``which_attn_to_use``.
        num_kv_heads: Forwarded to ``which_attn_to_use``.
        sliding_window: Forwarded to ``which_attn_to_use``.
        dtype: Forwarded to ``which_attn_to_use``.
        kv_cache_dtype: Forwarded to ``which_attn_to_use``.
        block_size: Forwarded to ``which_attn_to_use``.
        is_blocksparse: If True, skip normal selection and return the
            block-sparse FlashAttention backend.

    Returns:
        The selected attention backend class (not an instance).

    Raises:
        ValueError: If the selected backend has no import branch here, or
            if ``which_attn_to_use`` rejects the configured backend name.
    """
    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from vllm.attention.backends.blocksparse_attn import (
            BlocksparseFlashAttentionBackend)
        return BlocksparseFlashAttentionBackend

    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
                                sliding_window, dtype, kv_cache_dtype,
                                block_size)

    # Backend modules are imported inside each branch so that only the
    # selected backend's module (and its dependencies) is ever loaded.
    if backend == _Backend.FLASH_ATTN:
        logger.info("Using FlashAttention-2 backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    elif backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        logger.warning("Eager mode is required for the Flashinfer backend. "
                       "Please make sure --enforce-eager is set.")
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    else:
        raise ValueError(f"Invalid attention backend: {backend}")
71
72


73
def which_attn_to_use(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> _Backend:
    """Return which attention backend to use for the given configuration.

    Selection proceeds as follows:

    1. Start from ``FLASH_ATTN``; if ``VLLM_ATTENTION_BACKEND`` is set,
       use the member it names instead.
    2. On CPU, always return ``TORCH_SDPA``.
    3. On ROCm (HIP), always return ``ROCM_FLASH``; only the log differs.
    4. Otherwise (NVIDIA), demote ``FLASH_ATTN`` to ``XFORMERS`` when the
       device capability, dtype, KV-cache dtype, block size, sliding window,
       head size, or missing ``vllm_flash_attn`` package rules it out.

    ``num_heads`` and ``num_kv_heads`` are accepted but not consulted here.

    Raises:
        ValueError: If ``VLLM_ATTENTION_BACKEND`` is not a ``_Backend``
            member name (case-sensitive).
    """
    # Default case.
    selected_backend = _Backend.FLASH_ATTN

    # Check the environment variable and override if specified.
    backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
    if backend_by_env_var is not None:
        backend_members = _Backend.__members__
        if backend_by_env_var not in backend_members:
            raise ValueError(
                f"Invalid attention backend '{backend_by_env_var}'. "
                f"Available backends: {', '.join(backend_members)} "
                "(case-sensitive).")
        selected_backend = _Backend[backend_by_env_var]

    if is_cpu():
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs: FLASH_ATTN maps to ROCM_FLASH; any other request is
        # reported as unsupported and ROCM_FLASH is returned regardless.
        if selected_backend == _Backend.FLASH_ATTN:
            selected_backend = _Backend.ROCM_FLASH
        if selected_backend == _Backend.ROCM_FLASH:
            if torch.cuda.get_device_capability()[0] != 9:
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        return _Backend.ROCM_FLASH

    # FlashAttn in NVIDIA GPUs.
    if selected_backend == _Backend.FLASH_ATTN:
        if torch.cuda.get_device_capability()[0] < 8:
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            selected_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            selected_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            selected_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            selected_backend = _Backend.XFORMERS
        elif sliding_window is not None:
            logger.info(
                "Cannot use FlashAttention-2 backend due to sliding window.")
            selected_backend = _Backend.XFORMERS

    # FlashAttn is valid for the model; check that the package is installed
    # and that it supports this head size.
    if selected_backend == _Backend.FLASH_ATTN:
        try:
            import vllm_flash_attn  # noqa: F401

            from vllm.attention.backends.flash_attn import (  # noqa: F401
                FlashAttentionBackend)
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "vllm_flash_attn package is not found. "
                "`pip install vllm-flash-attn` for better performance.")
            selected_backend = _Backend.XFORMERS
        else:
            # Only reached when both imports succeeded.
            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size %d.",
                    head_size)
                selected_backend = _Backend.XFORMERS

    return selected_backend