from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

try:
    from spas_sage_attn.autotune import SparseAttentionMeansim
except ImportError:
    logger.info("SparseAttentionMeansim not found, please install sparge first")
    SparseAttentionMeansim = None

try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None

try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
    flash_attn_varlen_func_v3 = None

# On GPUs with compute capability <= 8.9 (pre-Hopper), use SageAttention's
# Triton qk-int8 / pv-fp16 kernel; otherwise use the default sageattn dispatch.
if torch.cuda.get_device_capability(0) <= (8, 9):
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
        logger.info("sageattn not found, please install sageattention first")
        sageattn = None
else:
    try:
        from sageattention import sageattn
    except ImportError:
        logger.info("sageattn not found, please install sageattention first")
        sageattn = None


from lightx2v.attentions.common.radial_attn import radial_attn


class AttnWeightTemplate(metaclass=ABCMeta):
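    """Base template for registered attention backends.

    Subclasses implement apply(); the load / to_cpu / to_cuda / state_dict
    hooks let attention "weights" be managed like ordinary model weights.
    """
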
    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.config = {}

    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self, non_blocking=False):
        pass

    def to_cuda(self, non_blocking=False):
        pass

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        return destination
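
# Usage sketch (hypothetical; the exact lookup call depends on lightx2v's
# registry_factory API):
#   attn = ATTN_WEIGHT_REGISTER["flash_attn2"]()
#   out = attn.apply(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)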


@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
        mask_map=None,
    ):
        # flash_attn packs all sequences along dim 0 (varlen layout); the final
        # reshape to (max_seqlen_q, heads * head_dim) assumes a single packed
        # sequence, i.e. total_q == max_seqlen_q.
        x = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        ).reshape(max_seqlen_q, -1)
        return x
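
# Example (a sketch, not lightx2v API): for a single packed sequence of length
# S with q/k/v shaped (S, num_heads, head_dim), flash_attn's varlen metadata is
#   cu = torch.tensor([0, S], dtype=torch.int32, device=q.device)
#   out = FlashAttn2Weight().apply(q, k, v, cu, cu, S, S)  # -> (S, heads * head_dim)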


@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
        mask_map=None,
    ):
        # flash_attn3's varlen kernel returns a tuple whose first element is
        # the attention output; reshape it to (max_seqlen_q, heads * head_dim).
        x = flash_attn_varlen_func_v3(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )[0].reshape(max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        mask_map=None,
        sparsity_type="radial",
        block_size=128,
        decay_factor=1,
        model_cls="wan",
    ):
        assert len(q.shape) == 3

        x = radial_attn(
            q,
            k,
            v,
            mask_map=mask_map,
            sparsity_type=sparsity_type,
            block_size=block_size,
            model_cls=model_cls[:3],  # truncate so every "wan*" variant maps to "wan"
            decay_factor=decay_factor,
        )
        x = x.view(max_seqlen_q, -1)
        return x
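
# Note: radial_attn builds a block-sparse attention pattern in block_size-sized
# tiles whose density is governed by decay_factor; see
# lightx2v.attentions.common.radial_attn for the mask construction.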


@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
        mask_map=None,
    ):
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        if model_cls == "hunyuan":
            # Hunyuan packs two segments, split at the first sequence boundary
            # (cu_seqlens[1]); attend within each segment separately, then
            # re-concatenate along the sequence dim. sageattn takes batched
            # NHD layout, hence the unsqueeze(0).
            x1 = sageattn(
                q[: cu_seqlens_q[1]].unsqueeze(0),
                k[: cu_seqlens_kv[1]].unsqueeze(0),
                v[: cu_seqlens_kv[1]].unsqueeze(0),
                tensor_layout="NHD",
            )
            x2 = sageattn(
                q[cu_seqlens_q[1] :].unsqueeze(0),
                k[cu_seqlens_kv[1] :].unsqueeze(0),
                v[cu_seqlens_kv[1] :].unsqueeze(0),
                tensor_layout="NHD",
            )
            x = torch.cat((x1, x2), dim=1)
            x = x.view(max_seqlen_q, -1)
        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df"]:
            x = sageattn(
                q.unsqueeze(0),
                k.unsqueeze(0),
                v.unsqueeze(0),
                tensor_layout="NHD",
            )
            x = x.view(max_seqlen_q, -1)
        else:
            # Previously x was left undefined here, raising a NameError.
            raise ValueError(f"Unsupported model_cls for sage_attn2: {model_cls}")
        return x


@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        drop_rate=0,
        attn_mask=None,
        causal=False,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
        mask_map=None,
    ):
        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        # F.scaled_dot_product_attention expects (batch, heads, seq, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Non-boolean masks are added to the attention scores, so they must
        # match the query dtype.
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
        # Back to (batch, seq, heads, head_dim), then merge the head dims.
        x = x.transpose(1, 2)
        out = x.flatten(2)
        return out.squeeze(0)
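
# Example (sketch): TorchSDPAWeight().apply(q, k, v) with q/k/v shaped
# (seq_len, num_heads, head_dim) returns (seq_len, num_heads * head_dim);
# unlike the varlen backends above, no cu_seqlens metadata is needed.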


@ATTN_WEIGHT_REGISTER("Sparge")
class SpargeAttnWeight(AttnWeightTemplate):
    def __init__(
        self,
        weight_name,
        verbose=False,
        l1=0.07,
        pv_l1=0.08,
        tune_pv=True,
        inner_attn_type="flash_attn3",
    ):
        # Plain assignments: the trailing commas previously turned these
        # attributes into one-element tuples.
        self.verbose = verbose
        self.l1 = l1
        self.pv_l1 = pv_l1
        self.tune_pv = tune_pv
        self.inner_attn_type = inner_attn_type
        self.inner_cls = SparseAttentionMeansim(l1=l1, pv_l1=pv_l1, tune_pv=tune_pv)
        super().__init__(weight_name)

    def load(self, weight_dict):
        # Match every key carrying this module's weight_name prefix and attach
        # the tuned sparse-attention parameters to the inner kernel.
        for key in weight_dict.keys():
            if key.startswith(self.weight_name):
                sub_name = key.split(".")[-1]
                setattr(
                    self.inner_cls,
                    sub_name,
                    nn.Parameter(weight_dict[key], requires_grad=False),
                )

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        # The sparse kernel expects batched (B, S, H, D) input in NHD layout;
        # add a batch dim for unbatched (S, H, D) tensors.
        if len(q.shape) == 3:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)

        x = self.inner_cls(q, k, v, tensor_layout="NHD")
        x = x.flatten(2)  # merge heads: (B, S, H * D)
        x = x.squeeze(0)

        return x
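
# Usage sketch (hypothetical prefix and state dict; sparge's autotuning
# produces keys like "<weight_name>.<param>"):
#   attn = ATTN_WEIGHT_REGISTER["Sparge"]("blocks.0.self_attn")
#   attn.load(tuned_state_dict)
#   out = attn.apply(q, k, v)  # q/k/v: (seq_len, num_heads, head_dim)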