fa_utils.py 4.24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from typing import Any

6
from vllm.logger import init_logger
7
from vllm.platforms import current_platform
8
9
10

logger = init_logger(__name__)

11
# Bind the attention kernels for the active platform to common module-level
# names so the rest of the code can import them from this module.
if current_platform.is_cuda():
    # CUDA: cache-write op from vllm._custom_ops, attention entry points from
    # vllm.vllm_flash_attn.
    from vllm._custom_ops import reshape_and_cache_flash
    from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
        flash_attn_varlen_func,
        get_scheduler_metadata,
    )
elif current_platform.is_xpu():
    # XPU: the same three names are aliased to their ipex_ops counterparts.
    from vllm._ipex_ops import ipex_ops

    reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash
    flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func  # type: ignore[assignment]
    get_scheduler_metadata = ipex_ops.get_scheduler_metadata  # type: ignore[assignment]
elif current_platform.is_rocm():
    # ROCm: no get_scheduler_metadata is bound in this branch.
    try:
        from vllm._custom_ops import reshape_and_cache_cuda
        from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
    except ImportError:
        # Either import above failing lands here. Only flash_attn_varlen_func
        # gets a stand-in; the other two names are not given fallbacks.
        def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any:  # type: ignore[no-redef,misc]
            """Placeholder that reports the missing flash-attn package on first call.

            Raises:
                ImportError: Always, regardless of the arguments passed.
            """
            message = (
                "ROCm platform requires upstream flash-attn "
                "to be installed. Please install flash-attn first."
            )
            raise ImportError(message)
35

36
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
    """Select the FlashAttention major version to use on the current platform.

    Args:
        requires_alibi: Whether the caller needs ALiBi. When True, a selected
            version of 3 is downgraded to 2 (with a one-time warning).

    Returns:
        ``2`` on XPU and on ROCm. Otherwise the selected version: by default 3
        when the device capability major is 9 and FA3 is supported, else 2,
        optionally overridden by the current vLLM config. ``None`` when the
        flash-attn interface cannot be imported, the device capability is
        unknown, the selected version is not supported, or a helper raises
        ``AssertionError``.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    if current_platform.is_xpu():
        return 2

    if current_platform.is_rocm():
        # This branch reports version 2 on ROCm.
        # NOTE(review): an earlier comment here said None was returned so that
        # callers skip the fa_version argument — confirm ROCm callers expect 2.
        return 2

    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason,
            is_fa_version_supported,
        )

        device_capability = current_platform.get_device_capability()
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would turn this into an AttributeError below.
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = (
            3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
        )

        # 2. override from the current vLLM config, when one is set
        from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if (
            vllm_config is not None
            and vllm_config.attention_config.flash_attn_version is not None
        ):
            fa_version = vllm_config.attention_config.flash_attn_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform, "
                "defaulting to FA version 2."
            )
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
            )
            fa_version = 2

        # Report and bail out explicitly; relying on a trailing `assert` would
        # return an unsupported version when assertions are disabled.
        if not is_fa_version_supported(fa_version):
            logger.error(
                "FA version %d is not supported due to %s",
                fa_version,
                fa_version_unsupported_reason(fa_version),
            )
            return None

        return fa_version
    except (ImportError, AssertionError):
        # Missing flash-attn interface, or an assertion raised inside one of
        # the helpers called above: treat as "no usable version".
        return None
95
96
97


def flash_attn_supports_fp8() -> bool:
    """Report whether FP8 is usable with flash attention on this platform.

    Returns:
        ``True`` on ROCm unconditionally. Elsewhere, true only when
        ``get_flash_attn_version()`` yields 3 and the device belongs to
        capability family 90.
    """
    if not current_platform.is_rocm():
        uses_fa3 = get_flash_attn_version() == 3
        # The capability query is only evaluated when FA3 was selected.
        return uses_fa3 and current_platform.is_device_capability_family(90)
    return True
104
105


106
107
108
109
110
def flash_attn_supports_sinks() -> bool:
    """Report whether attention sinks are supported.

    Returns:
        ``True`` on XPU; on any other platform, whether
        ``get_flash_attn_version()`` yields 3.
    """
    # The version lookup is skipped entirely on XPU.
    return True if current_platform.is_xpu() else get_flash_attn_version() == 3
111
112


113
114
def flash_attn_supports_mla() -> bool:
    """Report whether flash attention can be used for MLA on this platform.

    Returns:
        ``False`` on any non-CUDA platform. On CUDA, true only when FA version
        3 is supported and the device belongs to capability family 90;
        ``False`` if the flash-attn interface cannot be imported or a helper
        raises ``AssertionError``.
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return False

    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            is_fa_version_supported,
        )

        return is_fa_version_supported(
            3
        ) and current_platform.is_device_capability_family(90)
    except (ImportError, AssertionError):
        # Best-effort probe: an unavailable or failing interface simply means
        # "not supported" rather than an error for the caller.
        return False


130
def is_flash_attn_varlen_func_available() -> bool:
    """Report whether the current platform is CUDA, ROCm, or XPU.

    Returns:
        ``True`` when the current platform is one of those three. The
        platforms are probed in that order, stopping at the first match.
    """
    if current_platform.is_cuda():
        return True
    if current_platform.is_rocm():
        return True
    return current_platform.is_xpu()