selector.py 2.83 KB
Newer Older
1
import enum
2
import os
3
from functools import lru_cache
4
from typing import Type
5
6
7
8
9

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
10
from vllm.utils import is_cpu, is_hip
11
12
13

logger = init_logger(__name__)

# Name of the environment variable that can override the attention backend.
# Its value is looked up as a ``_Backend`` member name (e.g. "XFORMERS"), so
# it is case-sensitive; see ``_which_attn_to_use`` for when it is consulted.
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"

16

17
18
19
20
21
22
23
class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()


24
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
    """Return the attention backend class to use for ``dtype``.

    The backend kind is chosen by ``_which_attn_to_use``. The module that
    implements it is imported lazily inside the matching branch, so only the
    selected backend's module has to be importable.

    Results are memoized per ``dtype`` by ``lru_cache``. As a result, the
    platform checks and the ``VLLM_ATTENTION_BACKEND`` environment variable
    are consulted at most once per ``dtype`` in a process. Later changes to
    the variable are not observed for a ``dtype`` that was already resolved.

    Args:
        dtype: The dtype attention will run in.

    Returns:
        The ``AttentionBackend`` subclass for the selected backend.

    Raises:
        ValueError: If the selected ``_Backend`` member has no branch here.
            Any exception from ``_which_attn_to_use`` propagates unchanged.
    """
    backend = _which_attn_to_use(dtype)
    if backend == _Backend.FLASH_ATTN:
        logger.info("Using FlashAttention backend.")
        from vllm.attention.backends.flash_attn import FlashAttentionBackend
        return FlashAttentionBackend
    elif backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import XFormersBackend
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    else:
        # Defensive: reached only if a _Backend member is added without a
        # corresponding branch above.
        raise ValueError("Invalid attention backend.")
48
49


50
51
52
53
54
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
    """Return which attention backend to use for ``dtype`` on this platform.

    Selection order:

    1. If ``is_cpu()`` is true, return ``TORCH_SDPA``.
    2. If ``is_hip()`` is true, return ``ROCM_FLASH``. An informational
       message is logged when the device's compute-capability major version
       is not 9.
    3. Otherwise, fall back to ``XFORMERS`` in any of these cases:
       - the device's compute-capability major version is below 8;
       - ``dtype`` is neither ``torch.float16`` nor ``torch.bfloat16``;
       - the ``flash_attn`` package cannot be imported.
    4. Otherwise, if the ``VLLM_ATTENTION_BACKEND`` environment variable is
       set, return the ``_Backend`` member with that exact name. If it is not
       set, return ``FLASH_ATTN``.

    The environment variable is ignored in every case before step 4.

    Args:
        dtype: The dtype attention will run in.

    Returns:
        The selected ``_Backend`` member.

    Raises:
        ValueError: If ``VLLM_ATTENTION_BACKEND`` is consulted and its value
            is not a ``_Backend`` member name (case-sensitive).
    """
    if is_cpu():
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        if torch.cuda.get_device_capability()[0] != 9:
            # not Instinct series GPUs.
            logger.info("flash_atten is not supported on NAVI GPUs.")
        # ROCM_FLASH is returned regardless of the device generation.
        return _Backend.ROCM_FLASH

    # NVIDIA GPUs.
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention backend for Volta and Turing "
                    "GPUs.")
        return _Backend.XFORMERS

    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return _Backend.XFORMERS

    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention backend because the flash_attn package "
            "is not found. Please install it for better performance.")
        return _Backend.XFORMERS

    # The override is only consulted after every FlashAttention prerequisite
    # check above has passed; each earlier fallback ignores it.
    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
    if backend_by_env_var is not None:
        try:
            return _Backend[backend_by_env_var]
        except KeyError:
            # A bare KeyError('foo') names neither the env var nor the
            # accepted values. Raise ValueError, matching the invalid-backend
            # error in get_attn_backend.
            valid = ", ".join(_Backend.__members__)
            raise ValueError(
                f"Invalid value {backend_by_env_var!r} for the "
                f"{VLLM_ATTENTION_BACKEND} environment variable; expected one "
                f"of: {valid} (case-sensitive).") from None

    # Default case.
    return _Backend.FLASH_ATTN