selector.py 3.6 KB
Newer Older
1
import enum
2
from functools import lru_cache
3
from typing import Optional, Type
4
5
6

import torch

7
import vllm.envs as envs
8
9
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
10
from vllm.utils import is_cpu, is_hip
11
12
13
14

# Module-level logger used by get_attn_backend and _which_attn_to_use to
# report which attention backend was chosen and why a fallback happened.
logger = init_logger(__name__)


15
16
17
18
19
class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
20
    FLASHINFER = enum.auto()
21
22


23
@lru_cache(maxsize=None)
def get_attn_backend(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> Type[AttentionBackend]:
    """Resolve the attention backend class for the given model settings.

    The decision is delegated to ``_which_attn_to_use``; this function logs
    the outcome and imports the matching backend module only when that
    backend is selected. Results are memoized by ``lru_cache`` on the full
    argument tuple, so the selection and its log lines run once per
    distinct configuration.

    Returns:
        The ``AttentionBackend`` subclass for the chosen backend.

    Raises:
        ValueError: if the selection is not one of the handled ``_Backend``
            members.
    """
    chosen = _which_attn_to_use(num_heads, head_size, num_kv_heads,
                                sliding_window, dtype, kv_cache_dtype,
                                block_size)

    if chosen is _Backend.FLASH_ATTN:
        logger.info("Using FlashAttention-2 backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend

    if chosen is _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend

    if chosen is _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend

    if chosen is _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend

    if chosen is _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        logger.warning("Eager mode is enforced for the Flashinfer backend.")
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend

    # Only reachable if _Backend gains a member not handled above.
    raise ValueError("Invalid attention backend.")
62
63


64
65
66
67
68
69
70
71
72
def _which_attn_to_use(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> _Backend:
    """Return which attention backend to use on the current platform.

    Selection order:

    1. CPU builds always use ``TORCH_SDPA``.
    2. ROCm (HIP) builds always use ``ROCM_FLASH``.
    3. On NVIDIA GPUs, an explicit ``VLLM_ATTENTION_BACKEND`` override that
       names anything other than ``FLASH_ATTN`` is returned as-is.
    4. Otherwise FlashAttention-2 is the candidate. It falls back to
       ``XFORMERS`` when the GPU's compute capability major version is
       below 8, when ``dtype`` is neither fp16 nor bf16, or when the
       ``vllm_flash_attn`` package cannot be imported.

    Of the parameters, only ``dtype`` currently affects the decision; the
    others are forwarded by ``get_attn_backend`` and are unused here.

    Raises:
        KeyError: if ``VLLM_ATTENTION_BACKEND`` is set to a name that is not
            a member of ``_Backend``.
    """
    if is_cpu():
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        if torch.cuda.get_device_capability()[0] != 9:
            # not Instinct series GPUs.
            logger.info("flash_attn is not supported on NAVI GPUs.")
        return _Backend.ROCM_FLASH

    # NVIDIA GPUs. Resolve an explicit user override before the
    # FlashAttention-2 prerequisite checks. Those checks only matter for
    # FLASH_ATTN; consulting the override after them silently replaced an
    # explicit request such as FLASHINFER or TORCH_SDPA with XFORMERS
    # whenever vllm_flash_attn was missing or the GPU/dtype was unsupported.
    backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
    if backend_by_env_var is not None:
        requested = _Backend[backend_by_env_var]
        if requested != _Backend.FLASH_ATTN:
            return requested

    # FlashAttention-2 is the candidate (default or explicitly requested);
    # fall back to XFORMERS if any prerequisite is missing.
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
                    "GPUs.")
        return _Backend.XFORMERS

    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return _Backend.XFORMERS

    try:
        # Import only to probe availability; the backend itself is imported
        # lazily by get_attn_backend.
        import vllm_flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
            "package is not found. `pip install vllm-flash-attn` for better "
            "performance.")
        return _Backend.XFORMERS

    return _Backend.FLASH_ATTN