# selector.py: chooses which attention backend implementation to load.
import enum
from functools import lru_cache
from typing import Type

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip

# Module-level logger; the backend-selection functions in this module report
# which backend was chosen (or why a faster one was skipped) through it.
logger = init_logger(__name__)


15
16
17
18
19
class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
20
    FLASHINFER = enum.auto()
21
22


@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
    """Return the attention backend class to use for ``dtype``.

    The choice is delegated to ``_which_attn_to_use``. The corresponding
    backend module is then imported and its backend class returned.

    Results are memoized per ``dtype`` by ``lru_cache``. Selection and the
    accompanying log messages therefore happen once per distinct dtype.
    Later changes to ``VLLM_ATTENTION_BACKEND`` or the device are not picked
    up after the first call for that dtype.

    Args:
        dtype: The model's compute dtype.

    Returns:
        The ``AttentionBackend`` subclass implementing the chosen backend.

    Raises:
        ValueError: If the selected backend is not one handled here.
    """
    backend = _which_attn_to_use(dtype)
    # Backend modules are imported lazily inside each branch, presumably so
    # that optional dependencies (e.g. flashinfer) are only required when
    # that backend is actually selected.
    if backend == _Backend.FLASH_ATTN:
        logger.info("Using FlashAttention-2 backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    elif backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        logger.warning("Eager mode is enforced for the Flashinfer backend. ")
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    else:
        # Report the offending value to make misconfiguration diagnosable.
        raise ValueError(f"Invalid attention backend: {backend}")


def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
    """Decide which attention backend to use for ``dtype``.

    The checks run in this order, and the first one that matches wins:

    1. Running on CPU: ``TORCH_SDPA``.
    2. Running on ROCm/HIP: ``ROCM_FLASH``. An info message is logged when
       the device's major compute capability is not 9, but ``ROCM_FLASH`` is
       still returned.
    3. NVIDIA device with major compute capability < 8: ``XFORMERS``.
    4. ``dtype`` is neither ``torch.float16`` nor ``torch.bfloat16``:
       ``XFORMERS``.
    5. The ``flash_attn`` package cannot be imported: ``XFORMERS``.
    6. ``VLLM_ATTENTION_BACKEND`` is set: the ``_Backend`` member of that name.
    7. Otherwise: ``FLASH_ATTN``.

    Note that the ``VLLM_ATTENTION_BACKEND`` override is only consulted after
    checks 1-5. In any of those fallback cases the variable is ignored.

    Args:
        dtype: The model's compute dtype.

    Returns:
        The selected ``_Backend`` member.

    Raises:
        KeyError: If ``VLLM_ATTENTION_BACKEND`` is set to a string that is not
            the name of a ``_Backend`` member.
    """
    if is_cpu():
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        if torch.cuda.get_device_capability()[0] != 9:
            # not Instinct series GPUs.
            logger.info("flash_atten is not supported on NAVI GPUs.")
        return _Backend.ROCM_FLASH

    # NVIDIA GPUs.
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
                    "GPUs.")
        return _Backend.XFORMERS

    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return _Backend.XFORMERS

    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention-2 backend because the flash_attn "
            "package is not found. Please install it for better performance.")
        return _Backend.XFORMERS

    # Explicit user override; looked up by member name.
    # NOTE(review): an unknown name surfaces as a bare KeyError — consider a
    # clearer error listing the valid names.
    backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
    if backend_by_env_var is not None:
        return _Backend[backend_by_env_var]

    # Default case.
    return _Backend.FLASH_ATTN