# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Bind the platform-specific FlashAttention entry points to module-level names
# so callers can import them from this module without branching on platform.
# NOTE(review): on platforms that are neither CUDA nor XPU these names are
# left undefined, so importing them from this module fails there.
if current_platform.is_cuda():
    from vllm import _custom_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash
    from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash
    flash_attn_varlen_func = ops.flash_attn_varlen_func
    get_scheduler_metadata = ops.get_scheduler_metadata


def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
    """Select the FlashAttention version to use on the current platform.

    Resolution order:
      1. Default to FA3 on devices whose compute capability major version is 9
         when the installed ``vllm_flash_attn`` build supports it, else FA2.
      2. Override with ``VLLM_FLASH_ATTN_VERSION`` when it is set.
      3. Fall back to FA2 on compute-capability-10.x (Blackwell) devices and
         when ALiBi is required, since FA3 is not used in those cases.

    On XPU the function always returns 2.

    Args:
        requires_alibi: Whether the caller needs ALiBi; FA3 is not selected
            when this is True.

    Returns:
        The selected version (2 or 3), or ``None`` when FlashAttention cannot
        be used: the ``vllm_flash_attn`` interface fails to import, no device
        capability is reported, ``VLLM_FLASH_ATTN_VERSION`` is set to a value
        other than 2 or 3, or the selected version is not supported.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    if current_platform.is_xpu():
        return 2

    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason,
            is_fa_version_supported,
        )

        device_capability = current_platform.get_device_capability()
        # Explicit checks rather than `assert`: asserts are stripped under
        # `python -O`, which would let an invalid/unsupported version escape.
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = (
            3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
        )

        # 2. override if passed by environment (read the setting only once)
        env_fa_version = envs.VLLM_FLASH_ATTN_VERSION
        if env_fa_version is not None:
            if env_fa_version not in (2, 3):
                logger.error(
                    "Invalid VLLM_FLASH_ATTN_VERSION=%s; expected 2 or 3.",
                    env_fa_version,
                )
                return None
            fa_version = env_fa_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform, "
                "defaulting to FA version 2."
            )
            fa_version = 2

        if requires_alibi and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
            )
            fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error(
                "Cannot use FA version %d as it is not supported due to %s",
                fa_version,
                fa_version_unsupported_reason(fa_version),
            )
            return None

        return fa_version
    except (ImportError, AssertionError):
        # Kept as before: an ImportError or AssertionError raised while probing
        # the flash-attn interface is treated as "FlashAttention unavailable".
        return None


def flash_attn_supports_fp8() -> bool:
    """Report whether FlashAttention FP8 can be used on this device.

    True only when FA version 3 is the selected version and the device's
    compute capability major version is 9.
    """
    if get_flash_attn_version() != 3:
        return False
    return current_platform.get_device_capability().major == 9


def flash_attn_supports_sinks() -> bool:
    """Report whether attention sinks are supported by FlashAttention.

    Always True on XPU; on other platforms only when FA version 3 is selected.
    """
    if current_platform.is_xpu():
        return True
    return get_flash_attn_version() == 3


def flash_attn_supports_mla() -> bool:
    """Report whether the FlashAttention MLA path can be used.

    Only considered on CUDA. It requires the installed ``vllm_flash_attn``
    build to support FA version 3 and the device's compute capability major
    version to be 9.

    Returns:
        False on any non-CUDA platform, when the flash-attn interface cannot
        be imported, when no device capability is reported, or when the
        requirements above are not met; True otherwise.
    """
    if not current_platform.is_cuda():
        return False

    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            is_fa_version_supported,
        )

        if not is_fa_version_supported(3):
            return False
        device_capability = current_platform.get_device_capability()
        # Treat a missing capability as "not supported" instead of letting
        # `None[0]` raise an uncaught TypeError; use `.major` for consistency
        # with the other capability checks in this module.
        return device_capability is not None and device_capability.major == 9
    except (ImportError, AssertionError):
        return False


def is_flash_attn_varlen_func_available() -> bool:
    """Report whether ``flash_attn_varlen_func`` is provided on this platform.

    The module binds that kernel only for CUDA and XPU.
    """
    if current_platform.is_cuda():
        return True
    return current_platform.is_xpu()