selector.py 2.8 KB
Newer Older
1
import enum
2
from functools import lru_cache
3
from typing import Type
4
5
6

import torch

7
import vllm.envs as envs
8
9
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
10
from vllm.utils import is_cpu, is_hip
11
12
13
14

# Module-level logger; used below to report which attention backend was
# chosen and why a preferred backend could not be used.
logger = init_logger(__name__)


15
16
17
18
19
20
21
class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()


22
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
    """Return the attention backend class selected for ``dtype``.

    The decision itself is made by ``_which_attn_to_use``; this function
    maps the chosen ``_Backend`` member to its implementation class. Each
    backend module is imported only inside its own branch, so only the
    selected backend's module has to be importable.

    Results are cached per distinct ``dtype`` by ``lru_cache``, so later
    calls with the same ``dtype`` reuse the first selection instead of
    re-running the platform/environment checks.

    Args:
        dtype: Compute dtype, forwarded to the selection logic.

    Returns:
        The ``AttentionBackend`` subclass implementing the chosen backend.

    Raises:
        ValueError: If the selection logic returns a ``_Backend`` member
            that has no implementation mapped here.
    """
    backend = _which_attn_to_use(dtype)
    if backend is _Backend.FLASH_ATTN:
        logger.info("Using FlashAttention-2 backend.")
        from vllm.attention.backends.flash_attn import FlashAttentionBackend
        return FlashAttentionBackend
    if backend is _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import XFormersBackend
        return XFormersBackend
    if backend is _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    if backend is _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    # Name the offending value so a newly added, unmapped member is obvious.
    raise ValueError(f"Invalid attention backend: {backend}.")
46
47


48
49
50
51
52
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
    """Return which attention backend to use on the current platform.

    Checks run in this order; the first match wins:

    1. CPU: ``TORCH_SDPA``.
    2. ROCm/HIP: ``ROCM_FLASH`` (an info message is logged when the
       device's major compute capability is not 9).
    3. NVIDIA with major compute capability below 8 (Volta/Turing):
       ``XFORMERS``.
    4. ``dtype`` other than float16/bfloat16: ``XFORMERS``.
    5. ``flash_attn`` package not importable: ``XFORMERS``.
    6. ``VLLM_ATTENTION_BACKEND`` set: the ``_Backend`` member of that name.
    7. Otherwise: ``FLASH_ATTN``.

    The environment-variable override (step 6) is only consulted after
    every FlashAttention prerequisite has passed; when an earlier step
    returns, the variable is ignored.

    Args:
        dtype: Compute dtype; FlashAttention is only considered for
            ``torch.float16`` and ``torch.bfloat16``.

    Returns:
        The selected ``_Backend`` member.

    Raises:
        ValueError: If ``VLLM_ATTENTION_BACKEND`` is set to a string that is
            not a ``_Backend`` member name (lookup is case-sensitive).
    """
    if is_cpu():
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        if torch.cuda.get_device_capability()[0] != 9:
            # not Instinct series GPUs.
            logger.info("flash_attn is not supported on NAVI GPUs.")
        return _Backend.ROCM_FLASH

    # NVIDIA GPUs.
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
                    "GPUs.")
        return _Backend.XFORMERS

    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return _Backend.XFORMERS

    try:
        # Availability probe only: success of the import is all that matters.
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention-2 backend because the flash_attn "
            "package is not found. Please install it for better performance.")
        return _Backend.XFORMERS

    backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
    if backend_by_env_var is not None:
        try:
            return _Backend[backend_by_env_var]
        except KeyError:
            # A bare KeyError from the enum lookup gives no hint of what is
            # accepted; report the valid member names instead.
            valid = ", ".join(_Backend.__members__)
            raise ValueError(
                f"Invalid VLLM_ATTENTION_BACKEND {backend_by_env_var!r}; "
                f"expected one of: {valid} (case-sensitive).") from None

    # Default case.
    return _Backend.FLASH_ATTN