import math

from loguru import logger

try:
    import flash_attn  # noqa: F401
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None

try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
    flash_attn_varlen_func_v3 = None

try:
    import torch_mlu_ops as tmo
except ImportError:
    logger.info("torch_mlu_ops not found.")
    tmo = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
    """Variable-length attention backed by FlashAttention-2's ``flash_attn_varlen_func``."""

    def __init__(self):
        # NOTE(review): does not call super().__init__() — presumably
        # AttnWeightTemplate's init is not needed here; confirm.
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """Run varlen flash attention and flatten the result.

        Args:
            q, k, v: query/key/value tensors, 3-D (packed, batch of 1) or
                4-D with a leading batch dimension — assumed; confirm
                against callers.
            cu_seqlens_q, cu_seqlens_kv: cumulative sequence-length tensors
                passed straight through to flash_attn.
            max_seqlen_q, max_seqlen_kv: maximum sequence lengths.
            model_cls: unused; kept for interface compatibility with other
                registered attention weights.

        Returns:
            Attention output reshaped to ``(bs * max_seqlen_q, -1)``.

        Raises:
            ValueError: if ``q`` is neither 3-D nor 4-D.
        """
        if len(q.shape) == 3:
            bs = 1
        elif len(q.shape) == 4:
            bs = q.shape[0]
        else:
            # Fail fast: previously an unexpected rank left `bs` unbound
            # and raised an opaque UnboundLocalError at the reshape below.
            raise ValueError(f"Expected q to be 3-D or 4-D, got shape {tuple(q.shape)}")
        x = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        ).reshape(bs * max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
    """Variable-length attention backed by FlashAttention-3's varlen kernel."""

    def __init__(self):
        # NOTE(review): does not call super().__init__() — presumably
        # AttnWeightTemplate's init is not needed here; confirm.
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """Run varlen flash attention (v3) and flatten the result.

        Args:
            q, k, v: query/key/value tensors, 3-D (packed, batch of 1) or
                4-D with a leading batch dimension — assumed; confirm
                against callers.
            cu_seqlens_q, cu_seqlens_kv: cumulative sequence-length tensors.
            max_seqlen_q, max_seqlen_kv: maximum sequence lengths.
            model_cls: unused; kept for interface compatibility.

        Returns:
            Attention output reshaped to ``(bs * max_seqlen_q, -1)``.

        Raises:
            ValueError: if ``q`` is neither 3-D nor 4-D.
        """
        if len(q.shape) == 3:
            bs = 1
        elif len(q.shape) == 4:
            bs = q.shape[0]
        else:
            # Fail fast: previously an unexpected rank left `bs` unbound
            # and raised an opaque UnboundLocalError at the reshape below.
            raise ValueError(f"Expected q to be 3-D or 4-D, got shape {tuple(q.shape)}")
        x = flash_attn_varlen_func_v3(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        ).reshape(bs * max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("mlu_flash_attn")
class MluFlashAttnWeight(AttnWeightTemplate):
    """Variable-length attention backed by Cambricon MLU's ``torch_mlu_ops.flash_attention``."""

    def __init__(self):
        # NOTE(review): does not call super().__init__() — presumably
        # AttnWeightTemplate's init is not needed here; confirm.
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, **kws):
        """Run varlen flash attention on MLU and flatten the result.

        Args:
            q, k, v: query/key/value tensors, 3-D (unsqueezed here to a
                batch of 1, as tmo expects a batch dim — assumed; confirm)
                or 4-D with a leading batch dimension.
            cu_seqlens_q, cu_seqlens_kv: cumulative sequence-length tensors.
            max_seqlen_q, max_seqlen_kv: maximum sequence lengths.
            model_cls: unused; kept for interface compatibility.
            **kws: extra keyword arguments, ignored.

        Returns:
            Attention output reshaped to ``(bs * max_seqlen_q, -1)``.

        Raises:
            ValueError: if ``q`` is neither 3-D nor 4-D.
        """
        if len(q.shape) == 3:
            bs = 1
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
        else:
            # Fail fast: previously an unexpected rank left `bs` unbound
            # and raised an opaque UnboundLocalError at the reshape below.
            raise ValueError(f"Expected q to be 3-D or 4-D, got shape {tuple(q.shape)}")
        # Standard 1/sqrt(head_dim) softmax scaling.
        softmax_scale = 1 / math.sqrt(q.shape[-1])
        x = tmo.flash_attention(
            q=q,
            k=k,
            v=v,
            cu_seq_lens_q=cu_seqlens_q,
            cu_seq_lens_kv=cu_seqlens_kv,
            max_seq_len_q=max_seqlen_q,
            max_seq_len_kv=max_seqlen_kv,
            softmax_scale=softmax_scale,
            return_lse=False,
            out_dtype=q.dtype,
            is_causal=False,
            out=None,
            alibi_slope=None,
            attn_bias=None,
        )
        x = x.reshape(bs * max_seqlen_q, -1)
        return x