selector.py 2.61 KB
Newer Older
1
import enum
2
from functools import lru_cache
3
from typing import Type
4
5
6
7
8

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
9
from vllm.utils import is_cpu, is_hip
10
11
12
13

logger = init_logger(__name__)


14
15
16
17
18
19
20
@enum.unique
class _Backend(enum.Enum):
    """Internal identifiers for the attention backends this module can pick.

    Member values come from ``enum.auto()`` and carry no meaning of their
    own; members are only compared against each other.
    """

    # Selected when the `flash_attn` package can be imported.
    FLASH_ATTN = enum.auto()
    # Fallback whenever FlashAttention is ruled out on the NVIDIA path.
    XFORMERS = enum.auto()
    # Selected on the HIP (AMD GPU) path.
    ROCM_FLASH = enum.auto()
    # Selected on the CPU path.
    TORCH_SDPA = enum.auto()


21
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
    """Return the attention backend class to use for ``dtype``.

    The decision itself is delegated to ``_which_attn_to_use``. Each backend
    module is imported only inside the branch that selects it, so backends
    that are not chosen are never imported by this function. The result is
    memoized per ``dtype`` by ``lru_cache``, so the selection (and its log
    line) happens once per distinct argument.

    Args:
        dtype: Torch dtype forwarded to ``_which_attn_to_use``.

    Returns:
        The backend class (not an instance) matching the selected backend.

    Raises:
        ValueError: If the selected backend is none of the handled
            ``_Backend`` members.
    """
    selected = _which_attn_to_use(dtype)

    if selected == _Backend.FLASH_ATTN:
        logger.info("Using FlashAttention backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend

    if selected == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend

    if selected == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend

    if selected == _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend

    # Every known member returned above; anything else is unexpected.
    raise ValueError("Invalid attention backend.")
45
46


47
48
49
50
51
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
    """Choose the attention backend for the current platform and ``dtype``.

    Checks run in this order, and the first match wins:

    1. ``is_cpu()`` is true: ``TORCH_SDPA``.
    2. ``is_hip()`` is true: ``ROCM_FLASH`` (a message is logged when the
       device capability major version is not 9, but the result is the same).
    3. Device capability major version below 8: ``XFORMERS``.
    4. ``dtype`` is neither ``torch.float16`` nor ``torch.bfloat16``:
       ``XFORMERS``.
    5. ``flash_attn`` cannot be imported: ``XFORMERS``.
    6. Otherwise: ``FLASH_ATTN``.

    Args:
        dtype: Torch dtype used for the dtype check in step 4.

    Returns:
        The selected ``_Backend`` member.
    """
    if is_cpu():
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs. The capability check only affects logging; the ROCm
        # backend is returned either way.
        hip_major = torch.cuda.get_device_capability()[0]
        if hip_major != 9:
            # not Instinct series GPUs.
            logger.info("flash_atten is not supported on NAVI GPUs.")
        return _Backend.ROCM_FLASH

    # Everything below is the NVIDIA GPU path.
    cuda_major = torch.cuda.get_device_capability()[0]
    if cuda_major < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention backend for Volta and Turing "
                    "GPUs.")
        return _Backend.XFORMERS

    flash_dtypes = (torch.float16, torch.bfloat16)
    if dtype not in flash_dtypes:
        logger.info("Cannot use FlashAttention backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return _Backend.XFORMERS

    # Probe for the optional package; only its importability matters here.
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention backend because the flash_attn package "
            "is not found. Please install it for better performance.")
        return _Backend.XFORMERS
    else:
        return _Backend.FLASH_ATTN