# nbhd_attn.py: neighborhood sparse attention weight ("nbhd_attn") built on MagiAttention's flex_flash_attn_func.
import torch
from loguru import logger

try:
    from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
    magi_ffa_func = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


def generate_nbhd_mask(a, block_num, attnmap_frame_num, coefficient=(1.0, 0.5, 0.056), min_width=1.0, device="cpu"):
    """Build a block-level boolean neighborhood mask for frame-structured attention.

    The mask is the union of:
      * an "attention sink": every key block with index ``j <= a`` is visible to all queries;
      * for each frame offset ``interval`` in ``range(attnmap_frame_num)``, a band of
        half-width ``width_list[interval]`` around the diagonal shifted by ``interval * a``,
        restricted to the ``interval``-th frame after the query's frame (and, symmetrically,
        the ``interval``-th frame after the key's frame).

    Args:
        a: number of mask blocks per frame (may be fractional).
        block_num: number of blocks along each side of the square mask.
        attnmap_frame_num: total number of frames.
        coefficient: band widths for the first ``len(coefficient)`` frame offsets, in units
            of ``a``. Each scaled width is floored at ``min_width``. Any remaining offsets
            use ``min_width``. Any sequence of numbers is accepted.
        min_width: lower bound, in blocks, for every band width.
        device: device on which the mask is built.

    Returns:
        Bool tensor of shape ``(block_num, block_num)``, indexed ``[query_block, key_block]``;
        True marks block pairs that attend to each other.

    Raises:
        ValueError: if ``coefficient`` has more entries than ``attnmap_frame_num``.
    """
    coefficient = list(coefficient)
    # Explicit check rather than ``assert`` so the validation survives ``python -O``.
    if len(coefficient) > attnmap_frame_num:
        raise ValueError(f"coefficient length {len(coefficient)} should <= attnmap_frame_num {attnmap_frame_num}")

    i_indices = torch.arange(block_num, device=device).unsqueeze(1)  # [block_num, 1] query blocks
    j_indices = torch.arange(block_num, device=device).unsqueeze(0)  # [1, block_num] key blocks

    # After the check above this list has exactly attnmap_frame_num entries.
    width_list = [max(min_width, c * a) for c in coefficient] + [min_width] * (attnmap_frame_num - len(coefficient))
    logger.info(f"nbhd_attn width_list: {width_list}, len={len(width_list)}")

    # attention sink frame: j <= a
    mask_sink = j_indices <= a

    # Loop-invariant quantities, hoisted out of the per-offset loop.
    frame_i = i_indices // a  # frame index of each query block (float when ``a`` is fractional)
    frame_j = j_indices // a  # frame index of each key block
    diff = i_indices - j_indices  # [block_num, block_num] signed distance from the diagonal

    mask_sparse = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
    for interval, width in enumerate(width_list):
        # Key block lies in the interval-th frame after the query's frame ...
        mask_sparse_base_1 = (j_indices >= (frame_i + interval) * a) & (j_indices <= (frame_i + interval + 1) * a)
        # ... or the query block lies in the interval-th frame after the key's frame.
        mask_sparse_base_2 = (i_indices >= (frame_j + interval) * a) & (i_indices <= (frame_j + interval + 1) * a)

        # Keep only pairs whose offset from the shifted diagonal (+/- interval * a) is within width.
        mask_1 = mask_sparse_base_1 & (diff + (interval * a + width) >= 0) & (diff + (interval * a - width) <= 0)
        mask_2 = mask_sparse_base_2 & (diff - (interval * a - width) >= 0) & (diff - (interval * a + width) <= 0)

        mask_sparse |= mask_1 | mask_2

    return mask_sink | mask_sparse


def generate_qk_ranges(mask, block_size, seqlen):
    """Expand every nonzero entry of a block mask into a (query range, key range) pair.

    Args:
        mask: 2-D mask indexed as ``[query_block, key_block]``; nonzero entries are kept.
        block_size: number of tokens per block.
        seqlen: total token count; each range end is clipped to it so the last,
            possibly partial, block does not extend past the sequence.

    Returns:
        ``(q_ranges, k_ranges)``: two tensors of shape ``[N, 2]`` holding
        ``(block * block_size, min((block + 1) * block_size, seqlen))`` per selected
        block pair, where N is the number of nonzero mask entries, in
        ``torch.nonzero`` order.
    """
    coords = torch.nonzero(mask)  # [N, 2]: (row, col) of each selected block pair

    def _token_span(block_idx):
        # Start/end token offsets covered by each block index, end clipped to seqlen.
        start = block_idx * block_size
        stop = torch.clamp(start + block_size, max=seqlen)
        return torch.stack((start, stop), dim=1)

    return _token_span(coords[:, 0]), _token_span(coords[:, 1])


@ATTN_WEIGHT_REGISTER("nbhd_attn")
class NbhdAttnWeight(AttnWeightTemplate):
    """Neighborhood block-sparse attention computed with MagiAttention's flex_flash_attn.

    The block mask and the q/k ranges derived from it are cached on the class, so
    every instance shares them. They are rebuilt only when the sequence length,
    target device, or one of the mask settings below changes.
    """

    # Tokens per square mask block.
    block_size = 128
    # Sequence length the cached ranges were built for (None = nothing cached yet).
    seqlen = None
    # Number of frames the sequence is split into; must be set before first use.
    attnmap_frame_num = None
    # Cached inputs for magi_ffa_func, built by prepare_mask().
    q_ranges = None
    k_ranges = None
    attn_type_map = None
    # Band widths per frame offset (in units of blocks-per-frame) and their floor in blocks.
    coefficient = [1.0, 0.5, 0.056]
    min_width = 1.0
    # Full set of inputs the cached ranges were built from; see prepare_mask().
    _mask_key = None

    def __init__(self):
        self.config = {}

    @classmethod
    def prepare_mask(cls, seqlen, device="cuda"):
        """Build, or reuse, the class-wide q/k ranges for ``seqlen`` tokens.

        Args:
            seqlen: number of query/key tokens.
            device: device to place the range tensors on. The default is ``"cuda"``,
                which matches the previous hard-coded behaviour.

        Raises:
            ValueError: if ``attnmap_frame_num`` has not been set to a positive value.
        """
        # Key on every input that shapes the mask, not only seqlen, so changing the
        # frame count, widths, block size or device cannot reuse stale ranges.
        cache_key = (seqlen, str(device), cls.block_size, cls.attnmap_frame_num, tuple(cls.coefficient), cls.min_width)
        if seqlen == cls.seqlen and cache_key == cls._mask_key:
            return
        if not cls.attnmap_frame_num:
            raise ValueError("NbhdAttnWeight.attnmap_frame_num must be set to a positive frame count before use")

        block_num = (seqlen + cls.block_size - 1) // cls.block_size  # ceil division
        # Blocks per frame; intentionally fractional when frames don't align with blocks.
        block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
        mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")

        q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
        cls.q_ranges = q_ranges.to(device=device, dtype=torch.int32)
        cls.k_ranges = k_ranges.to(device=device, dtype=torch.int32)
        # One entry per range, all zero (attention type 0 for every range).
        cls.attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device=device)
        cls.seqlen = seqlen
        cls._mask_key = cache_key
        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")

        sparsity = 1 - mask.sum().item() / mask.numel()
        logger.info(f"Attention sparsity: {sparsity}")

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """Run neighborhood sparse attention over a single sequence.

        q: [seqlen, head_num, head_dim]
        k: [seqlen, head_num, head_dim]
        v: [seqlen, head_num, head_dim]

        ``cu_seqlens_*``, ``max_seqlen_*`` and ``model_cls`` are accepted for interface
        compatibility and are not used. The first output of ``magi_ffa_func`` is returned,
        reshaped to ``(out.shape[0], -1)``.

        Raises:
            ImportError: if ``magi_attention`` could not be imported.
        """
        if magi_ffa_func is None:
            raise ImportError("nbhd_attn requires magi_attention (magi_attention.functional.flex_flash_attn_func), which could not be imported")
        # Build the ranges on q's device so they match the inputs (e.g. on multi-GPU setups).
        self.prepare_mask(seqlen=q.shape[0], device=q.device)
        out = magi_ffa_func(
            q,
            k,
            v,
            q_ranges=self.q_ranges,
            k_ranges=self.k_ranges,
            attn_type_map=self.attn_type_map,
            auto_range_merge=True,
        )[0]
        return out.reshape(out.shape[0], -1)