# selector.py: selects the attention backend class (FlashAttention, Torch SDPA, or XFormers).
1
from functools import lru_cache
2
from typing import Type
3
4
5
6
7

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
8
from vllm.utils import is_cpu, is_hip
9
10
11
12
13

logger = init_logger(__name__)


@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
    """Select the attention backend class to use for ``dtype``.

    Backends are tried in this order:

    1. FlashAttention, when ``_can_use_flash_attn(dtype)`` returns True.
    2. Torch SDPA, when ``is_cpu()`` returns True.
    3. XFormers, otherwise.

    Each backend module is imported inside the branch that selects it, so
    only the chosen backend's module is imported by this function.

    The result is memoized per ``dtype`` via ``lru_cache``, so the selection
    logic (and its log message) runs once for each distinct dtype.

    Args:
        dtype: The torch dtype the attention computation will run in.

    Returns:
        The selected backend class (the class itself, not an instance).
    """
    if _can_use_flash_attn(dtype):
        logger.info("Using FlashAttention backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    elif is_cpu():
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    else:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend


def _can_use_flash_attn(dtype: torch.dtype) -> bool:
    """Return whether the FlashAttention backend can be used for ``dtype``.

    The checks run in order and the first failure returns False:

    1. The platform is not HIP (AMD GPUs).
    2. The platform is not CPU.
    3. The current CUDA device has a compute capability major version >= 8.
    4. ``dtype`` is ``torch.float16`` or ``torch.bfloat16``.
    5. The ``flash_attn`` package can be imported.

    Every rejection except the CPU one logs its reason via ``logger.info``.

    Args:
        dtype: The torch dtype the attention computation will run in.

    Returns:
        True if all checks pass, False otherwise.
    """
    if is_hip():
        # AMD GPUs.
        logger.info("Cannot use FlashAttention backend for AMD GPUs.")
        return False
    # Checked before the CUDA capability query below; presumably that query
    # requires a CUDA device to be present.
    if is_cpu():
        return False
    # Element 0 of the capability tuple is the major version.
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention backend for Volta and Turing "
                    "GPUs.")
        return False
    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return False

    # Probe by actually importing the package; the module itself is unused.
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention because the package is not found. "
            "Please install it for better performance.")
        return False
    return True