sparge_attn.py 1.73 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
import torch
import torch.nn as nn
PengGao's avatar
PengGao committed
3
4
5
6
7
from loguru import logger

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate
helloyongyang's avatar
helloyongyang committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

# Optional dependency: the Sparge kernel ships in the external
# ``spas_sage_attn`` package. If it cannot be imported, the name is bound to
# ``None`` so this module still imports; the kernel class is only needed when
# ``SpargeAttnWeight`` is instantiated.
try:
    from spas_sage_attn.autotune import SparseAttentionMeansim
except ImportError:
    SparseAttentionMeansim = None
    logger.info("SparseAttentionMeansim not found, please install sparge first")


@ATTN_WEIGHT_REGISTER("Sparge")
class SpargeAttnWeight(AttnWeightTemplate):
    """Attention weight that delegates to Sparge's ``SparseAttentionMeansim``.

    Registered in ``ATTN_WEIGHT_REGISTER`` under the key ``"Sparge"``. The
    attention computation itself is performed by ``self.inner_cls``, an
    instance of ``spas_sage_attn.autotune.SparseAttentionMeansim``.
    :meth:`load` attaches pre-tuned tensors to that instance.
    """

    def __init__(
        self,
        weight_name,
        verbose=False,
        l1=0.07,
        pv_l1=0.08,
        tune_pv=True,
        inner_attn_type="flash_attn3",
    ):
        """Create the wrapper and its inner ``SparseAttentionMeansim`` kernel.

        Args:
            weight_name: Key prefix used by :meth:`load` to select entries of
                the weight dict; also forwarded to ``AttnWeightTemplate``.
            verbose: Stored on the instance; not forwarded to the kernel here.
            l1: Forwarded to ``SparseAttentionMeansim`` as ``l1``.
            pv_l1: Forwarded to ``SparseAttentionMeansim`` as ``pv_l1``.
            tune_pv: Forwarded to ``SparseAttentionMeansim`` as ``tune_pv``.
            inner_attn_type: Stored on the instance; not read anywhere in
                this class.

        Raises:
            ImportError: If ``spas_sage_attn`` could not be imported when this
                module was loaded (the kernel class is then ``None``).
        """
        # Fail with an actionable message instead of "'NoneType' object is
        # not callable" when the optional dependency is missing.
        if SparseAttentionMeansim is None:
            raise ImportError("SparseAttentionMeansim not found, please install sparge first")

        # Store the plain values. A trailing comma here would silently wrap
        # each one in a 1-tuple (e.g. ``(False,)``, which is truthy).
        self.verbose = verbose
        self.l1 = l1
        self.pv_l1 = pv_l1
        self.tune_pv = tune_pv
        self.inner_attn_type = inner_attn_type
        self.inner_cls = SparseAttentionMeansim(l1=l1, pv_l1=pv_l1, tune_pv=tune_pv)
        super().__init__(weight_name)

    def load(self, weight_dict):
        """Attach every entry whose key starts with ``self.weight_name``.

        Each matching tensor is wrapped in a frozen ``nn.Parameter``
        (``requires_grad=False``). It is then set as an attribute on
        ``self.inner_cls``, named after the last dot-separated component of
        its key.

        Args:
            weight_dict: Mapping supporting ``.keys()`` and ``[key]`` lookup,
                presumably from string names to tensors.
        """
        for key in weight_dict.keys():
            # NOTE(review): this is a plain string-prefix match, so a
            # weight_name like "blocks.1" would also match "blocks.10.*".
            # Confirm weight_name values make this unambiguous.
            if not key.startswith(self.weight_name):
                continue
            attr_name = key.rsplit(".", 1)[-1]
            setattr(
                self.inner_cls,
                attr_name,
                nn.Parameter(weight_dict[key], requires_grad=False),
            )

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """Run Sparge sparse attention over ``q``, ``k`` and ``v``.

        The tensors are handed to the kernel with ``tensor_layout="NHD"``.
        This presumably means (batch, seq, heads, head_dim) for 4-D inputs;
        TODO confirm against the SpargeAttn docs. 3-D inputs first get a
        leading dimension of size 1.

        ``cu_seqlens_*``, ``max_seqlen_*`` and ``model_cls`` are accepted
        but unused here, presumably to match the signature of other
        attention weights.

        Returns:
            The kernel output, flattened from dim 2 onward, with dim 0
            squeezed.
        """
        if q.dim() == 3:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)

        out = self.inner_cls(q, k, v, tensor_layout="NHD")
        out = out.flatten(2)
        # NOTE(review): squeeze(0) is unconditional. It also drops a size-1
        # batch dim on genuine 4-D inputs, and is a no-op when batch > 1.
        return out.squeeze(0)