"vllm/vscode:/vscode.git/clone" did not exist on "54c892438479c0a8aec9a10b8db570034af92443"
fa_utils.py 5.85 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from typing import Any

6
from vllm.logger import init_logger
7
from vllm.platforms import current_platform
8
9
10

logger = init_logger(__name__)

# Records whether the upstream ``flash_attn`` package imported successfully on
# ROCm. It starts as False and is flipped to True only by the ROCm branch below
# while this module is being imported, so later checks can consult this flag
# instead of re-attempting the import (same idea as IS_AITER_FOUND in
# _aiter_ops.py).
_ROCM_FLASH_ATTN_AVAILABLE = False

# Bind the platform-specific kernels to common module-level names:
# ``flash_attn_varlen_func``, ``get_scheduler_metadata`` and
# ``reshape_and_cache_flash``.
# NOTE(review): there is no branch for platforms other than CUDA/XPU/ROCm, so on
# those platforms none of these names exist in this module.
if current_platform.is_cuda():
    from vllm._custom_ops import reshape_and_cache_flash
    from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
        flash_attn_varlen_func,
        get_scheduler_metadata,
    )
elif current_platform.is_xpu():
    from vllm import _custom_ops as ops
    from vllm._xpu_ops import xpu_ops

    # XPU takes the cache writer from the custom ops and the attention entry
    # points from xpu_ops.
    reshape_and_cache_flash = ops.reshape_and_cache_flash
    flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func  # type: ignore[assignment]
    get_scheduler_metadata = xpu_ops.get_scheduler_metadata  # type: ignore[assignment]
elif current_platform.is_rocm():
    try:
        from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]
    except ImportError:
        # Keep the name importable so module import succeeds; the missing
        # dependency is only reported when the function is actually called.
        def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any:  # type: ignore[no-redef,misc]
            raise ImportError(
                "ROCm platform requires upstream flash-attn "
                "to be installed. Please install flash-attn first."
            )
    else:
        # Reached only when the upstream import above did not raise.
        _ROCM_FLASH_ATTN_AVAILABLE = True

    # Scheduler metadata is an FA3 feature that ROCm does not use; this stub
    # always yields None.
    def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
        return None

    # On ROCm the cache writer comes from the C++ custom ops.
    from vllm import _custom_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash

54

55
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
    """Select the FlashAttention version to use on the current platform.

    Selection order:
      * XPU always uses version 2.
      * ROCm returns None, so callers skip passing an ``fa_version`` argument.
      * Otherwise (vllm_flash_attn path):
          1. Default to 3 on devices with capability major 9 when FA3 is
             supported, else 2.
          2. Let ``attention_config.flash_attn_version`` from the current vLLM
             config override that default, if a config is active and the
             value is set.
          3. Downgrade 3 -> 2 on capability major 10 (Blackwell) or when ALiBi
             is required.
          4. Return None if the resulting version is not supported.

    Args:
        requires_alibi: Whether the caller needs ALiBi. FA3 is never selected
            in that case.

    Returns:
        The FA version (2 or 3, or a config-provided value that
        ``is_fa_version_supported`` accepts). Returns None on ROCm, when
        ``vllm.vllm_flash_attn`` cannot be imported, when the device
        capability is unavailable, or when the chosen version is unsupported.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    if current_platform.is_xpu():
        return 2

    if current_platform.is_rocm():
        # ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
        return None

    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason,
            is_fa_version_supported,
        )

        device_capability = current_platform.get_device_capability()
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would turn this into an AttributeError below.
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = (
            3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
        )

        # 2. override from the active vLLM config, if one is set
        from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if (
            vllm_config is not None
            and vllm_config.attention_config.flash_attn_version is not None
        ):
            fa_version = vllm_config.attention_config.flash_attn_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform, "
                "defaulting to FA version 2."
            )
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
            )
            fa_version = 2

        # Explicit check rather than `assert`: under `python -O` an assert
        # would be stripped and an unsupported version returned to the caller.
        if not is_fa_version_supported(fa_version):
            logger.error(
                "Cannot use FA version %d because it is not supported: %s",
                fa_version,
                fa_version_unsupported_reason(fa_version),
            )
            return None

        return fa_version
    except (ImportError, AssertionError):
        # AssertionError is still caught in case the imported helpers assert
        # internally (not visible here); both cases mean "no usable FA".
        return None
114
115
116


def flash_attn_supports_fp8() -> bool:
    """Report whether FlashAttention FP8 can be used.

    Requires that FA version 3 is selected and that the device belongs to
    capability family 90 (as reported by ``is_device_capability_family``).
    """
    if get_flash_attn_version() != 3:
        return False
    return current_platform.is_device_capability_family(90)
121
122


123
124
125
126
127
128
129
def flash_attn_supports_sinks() -> bool:
    """Report whether FlashAttention attention sinks are supported.

    XPU reports support unconditionally; elsewhere support requires FA
    version 3.
    """
    return True if current_platform.is_xpu() else get_flash_attn_version() == 3


130
131
def flash_attn_supports_mla() -> bool:
    """Report whether FlashAttention can be used for MLA.

    Returns True only on CUDA, when vllm_flash_attn reports FA version 3 as
    supported and the device belongs to capability family 90. Any
    ImportError or AssertionError raised while probing counts as "not
    supported".
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return False

    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            is_fa_version_supported,
        )

        return is_fa_version_supported(
            3
        ) and current_platform.is_device_capability_family(90)
    except (ImportError, AssertionError):
        return False


147
def is_flash_attn_varlen_func_available() -> bool:
    """Tell whether the module-level ``flash_attn_varlen_func`` is usable.

    The module binds ``flash_attn_varlen_func`` at import time from a
    platform-specific source:

    - CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func
    - XPU: xpu_ops.flash_attn_varlen_func
    - ROCm: upstream flash_attn.flash_attn_varlen_func, or a stub that raises
      ImportError when the upstream package is missing

    This is independent of the AITER flash attention backend
    (rocm_aiter_fa.py), which calls rocm_aiter_ops.flash_attn_varlen_func and
    is gated by _aiter_ops.is_aiter_found_and_supported().

    Returns:
        bool: True on CUDA and XPU. On ROCm, True only if the upstream import
        succeeded. False on any other platform.
    """
    if current_platform.is_cuda() or current_platform.is_xpu():
        return True

    # On ROCm, answer from the flag recorded when the module was imported.
    return _ROCM_FLASH_ATTN_AVAILABLE if current_platform.is_rocm() else False