ring_globals.py 2.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention

# FlashAttention-2 probe: HAS_FLASH_ATTN is True only when both the `flash_attn`
# package and its low-level `_flash_attn_forward` entry point can be imported.
try:
    import flash_attn  # noqa: F401
    from flash_attn.flash_attn_interface import _flash_attn_forward  # noqa: F401
except ImportError:
    # ModuleNotFoundError subclasses ImportError, so a missing package lands here too.
    HAS_FLASH_ATTN = False
else:
    HAS_FLASH_ATTN = True

# FA3 detection: try multiple sources (forward only, no backward needed for inference)
# Source 1: flash_attn_interface (from flash-attention source build)
# Source 2: fa3_fwd_interface (from fa3-fwd PyPI package, supports Ampere/Ada/Hopper)
# Note: FA3 high-level API may or may not return softmax_lse depending on version.
#       For Ring Attention, which requires the LSE, the low-level API is the
#       fallback when needed.
#
# The three names below are published together: either HAS_FA3 is True and both
# functions come from the same source module, or HAS_FA3 is False and both are
# None. Each source is imported into private temporaries first, so a source that
# only partially resolves (e.g. `_flash_attn_forward` present but
# `flash_attn_func` missing) cannot leave a stale function behind while HAS_FA3
# is False.
HAS_FA3 = False
fa3_fwd_func = None  # Low-level forward function (_flash_attn_forward)
fa3_attn_func = None  # High-level attention function (flash_attn_func)

# Try flash_attn_interface first (from flash-attention source build)
try:
    from flash_attn_interface import _flash_attn_forward as _fa3_fwd_candidate
    from flash_attn_interface import flash_attn_func as _fa3_attn_candidate
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    pass
else:
    fa3_fwd_func = _fa3_fwd_candidate
    fa3_attn_func = _fa3_attn_candidate
    HAS_FA3 = True

# Fallback: try fa3_fwd_interface (PyPI package, supports Ampere/Ada/Hopper)
if not HAS_FA3:
    try:
        from fa3_fwd_interface import _flash_attn_forward as _fa3_fwd_candidate
        from fa3_fwd_interface import flash_attn_func as _fa3_attn_candidate
    except ImportError:
        pass
    else:
        fa3_fwd_func = _fa3_fwd_candidate
        fa3_attn_func = _fa3_attn_candidate
        HAS_FA3 = True

# Backward-compatible aliases: the older names are bound to the FA3 probe
# results (HAS_FA3, fa3_fwd_func, fa3_attn_func) so both spellings agree.
HAS_FLASH_ATTN_HOPPER, flash_attn_forward_hopper, flash3_attn_func = (
    HAS_FA3,
    fa3_fwd_func,
    fa3_attn_func,
)

# FlashInfer probe: HAS_FLASHINFER reports whether
# `flashinfer.prefill.single_prefill_with_kv_cache` could be imported.
try:
    from flashinfer.prefill import single_prefill_with_kv_cache  # noqa: F401
except ImportError:
    HAS_FLASHINFER = False
else:
    HAS_FLASHINFER = True

# AITER probe: both the `aiter` package and its `flash_attn_func` (exposed here
# as `flash_attn_func_aiter`) must import for HAS_AITER to be True.
try:
    import aiter  # noqa: F401
    from aiter import flash_attn_func as flash_attn_func_aiter  # noqa: F401
except ImportError:
    HAS_AITER = False
else:
    HAS_AITER = True

# SageAttention probe: True when the `sageattention` module imports.
try:
    import sageattention  # noqa: F401
except ImportError:
    HAS_SAGE_ATTENTION = False
else:
    HAS_SAGE_ATTENTION = True

# Sparse SageAttention probe: True when the `spas_sage_attn` module imports.
try:
    import spas_sage_attn  # noqa: F401
except ImportError:
    HAS_SPARSE_SAGE_ATTENTION = False
else:
    HAS_SPARSE_SAGE_ATTENTION = True

# NPU probe: HAS_NPU is True when the `torch_npu` module imports.
try:
    import torch_npu  # noqa: F401
except ImportError:
    HAS_NPU = False
else:
    HAS_NPU = True