# sparse_mask_generator.py
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F
from loguru import logger

from lightx2v.utils.registry_factory import SPARSE_MASK_GENERATOR_REGISTER

from .nbhd_attn import generate_nbhd_mask
from .svg_attn import diagonal_band_mask_from_sparsity, get_attention_mask, wan_hidden_states_placement, wan_sparse_head_placement
from .utils.sla_util import get_block_map


class GeneralMaskGenerator(ABC):
    """Base class for block-level sparse attention mask generators.

    Subclasses implement ``__call__(q, k)`` and return a block sparse mask of
    shape ``[B, H, Q_block_num, K_block_num]``. ``reorg``/``restore`` are
    optional hooks for reordering q/k/v before attention and undoing that
    reordering on the output; both default to identity.
    """

    def __init__(self, q_block_size=128, k_block_size=128, sparse_setting=None, attnmap_frame_num=None):
        # None sentinel instead of a mutable `{}` default: a dict default would
        # be created once and shared by every instance of every subclass.
        self.sparse_setting = {} if sparse_setting is None else sparse_setting
        self.q_block_size = q_block_size
        self.k_block_size = k_block_size
        # Number of frames spanned by the attention map; frame-aware
        # generators (nbhd/svg) divide the sequence length by this.
        self.attnmap_frame_num = attnmap_frame_num

    @abstractmethod
    def __call__(self, q, k):
        """Return the block sparse mask for query/key tensors of shape (L, H, D)."""
        pass

    def reorg(self, q, k, v):
        """Optionally reorder q/k/v before attention. Identity by default."""
        return q, k, v

    def restore(self, out):
        """Undo ``reorg`` on the attention output. Identity by default."""
        return out


@SPARSE_MASK_GENERATOR_REGISTER("sla_mask_generator")
class SlaMaskGenerator(GeneralMaskGenerator):
    """SLA (sparse linear attention) mask generator.

    Keeps the top-k most important K blocks per Q block, where the kept
    fraction is ``1 - sla_sparsity_ratio`` (default sparsity 0.8).
    """

    def __init__(self, q_block_size=128, k_block_size=128, sparse_setting=None, attnmap_frame_num=None):
        # None sentinel avoids sharing one mutable `{}` default across instances.
        super().__init__(q_block_size, k_block_size, {} if sparse_setting is None else sparse_setting, attnmap_frame_num)
        sparsity_ratio = self.sparse_setting.get("sla_sparsity_ratio", 0.8)
        # Fraction of K blocks retained per Q block.
        self.topk_ratio = 1 - sparsity_ratio

    def __call__(self, q, k):
        # (L, H, D) -> (B, H, L, D)
        q = q.unsqueeze(0).transpose(1, 2).contiguous()
        k = k.unsqueeze(0).transpose(1, 2).contiguous()
        # get_block_map also returns a LUT and top-k counts; only the map is used here.
        sparse_map, _lut, _topk = get_block_map(q, k, topk_ratio=self.topk_ratio, BLKQ=self.q_block_size, BLKK=self.k_block_size)
        # return: [B, H, Q_block_num, K_block_num]
        return sparse_map


@SPARSE_MASK_GENERATOR_REGISTER("nbhd_mask_generator")
class NbhdMaskGenerator(GeneralMaskGenerator):
    """Neighborhood-attention mask generator.

    Builds a frame-local band mask via ``generate_nbhd_mask`` and caches it at
    class level, keyed on the sequence length only.
    """

    # Class-level cache: last seqlen the mask was built for, and the mask itself.
    # NOTE(review): the cache key ignores head_num and q.device — presumably both
    # are constant for the lifetime of the process; verify against callers.
    seqlen = None
    mask = None

    def __init__(self, q_block_size=128, k_block_size=128, sparse_setting=None, attnmap_frame_num=None):
        # None sentinel avoids sharing one mutable `{}` default across instances.
        super().__init__(q_block_size, k_block_size, {} if sparse_setting is None else sparse_setting, attnmap_frame_num)
        # Band-width coefficients and minimum width for the neighborhood mask.
        self.coefficient = self.sparse_setting.get("nbhd_coefficient", [1.0, 0.5, 0.056])
        self.min_width = self.sparse_setting.get("nbhd_min_width", 1.0)
        self.block_size = self.q_block_size

    def __call__(self, q, k):
        seqlen, head_num, head_dim = q.shape
        # Reuse the cached mask when the sequence length is unchanged.
        if seqlen == NbhdMaskGenerator.seqlen:
            return NbhdMaskGenerator.mask
        block_num = (seqlen + self.block_size - 1) // self.block_size
        # Fractional blocks per frame (frames need not align to block boundaries).
        block_num_per_frame = seqlen / self.attnmap_frame_num / self.block_size
        mask = generate_nbhd_mask(block_num_per_frame, block_num, self.attnmap_frame_num, coefficient=self.coefficient, min_width=self.min_width, device=q.device)
        # Broadcast the per-block mask to every head: [1, H, Q_blocks, K_blocks].
        mask = mask[None, None, :, :].repeat(1, head_num, 1, 1)
        # return: [B, H, Q_block_num, K_block_num]
        NbhdMaskGenerator.seqlen = seqlen
        NbhdMaskGenerator.mask = mask
        return mask


@SPARSE_MASK_GENERATOR_REGISTER("svg_mask_generator")
class SvgMaskGenerator(GeneralMaskGenerator):
    """Sparse-VideoGen (SVG) mask generator.

    For each head, samples a few query rows, measures the MSE between full
    attention and each candidate pattern (spatial / temporal), and picks the
    better pattern per head. ``reorg``/``restore`` permute q/k/v (and the
    attention output) so that heads with the same pattern are grouped, while
    ``__call__`` returns a diagonal-band block mask at the target sparsity.
    """

    # Class-level cache keyed on sequence length.
    seqlen = None
    attention_masks = None
    mask = None

    def __init__(self, q_block_size=128, k_block_size=128, sparse_setting=None, attnmap_frame_num=None):
        # None sentinel avoids sharing one mutable `{}` default across instances.
        super().__init__(q_block_size, k_block_size, {} if sparse_setting is None else sparse_setting, attnmap_frame_num)
        self.sample_mse_max_row = self.sparse_setting.get("svg_sample_mse_max_row", 10000)
        self.num_sampled_rows = self.sparse_setting.get("svg_num_sampled_rows", 64)
        self.context_length = self.sparse_setting.get("svg_context_length", 0)
        self.sparsity = self.sparse_setting.get("svg_sparsity", 0.75)
        self.block_size = self.k_block_size
        # Per-head winning pattern index; set by reorg(), consumed by restore().
        # (Original code initialized a misspelled `best_model_idx` that nothing
        # ever read — the attribute actually used everywhere is `best_mask_idx`.)
        self.best_mask_idx = None
        self.head_num = None
        self.head_dim = None

    def prepare_mask(self, q):
        """Build (and cache) the candidate attention masks and the block mask."""
        seqlen, head_num, head_dim = q.shape
        if seqlen == SvgMaskGenerator.seqlen:
            return
        logger.info(f"SvgMaskGenerator: Preparing mask for seqlen={seqlen}, head_num={head_num}, head_dim={head_dim}")
        frame_size = seqlen // self.attnmap_frame_num
        SvgMaskGenerator.attention_masks = [get_attention_mask(mask_name, self.sample_mse_max_row, self.context_length, self.attnmap_frame_num, frame_size) for mask_name in ["spatial", "temporal"]]
        block_num = (seqlen + self.block_size - 1) // self.block_size
        block_num_per_frame = block_num // self.attnmap_frame_num
        mask = diagonal_band_mask_from_sparsity(block_num, block_num_per_frame, self.sparsity, device=q.device)
        SvgMaskGenerator.mask = mask[None, None, :, :].repeat(1, head_num, 1, 1)
        SvgMaskGenerator.seqlen = seqlen

    def sample_mse(self, query, key, value):
        """Return per-(mask, cfg, head) MSE between masked and full attention on sampled rows.

        query/key/value: (cfg, num_heads, seq_len, dim).
        """
        cfg, num_heads, seq_len, dim = query.size()
        num_sampled_rows = min(self.num_sampled_rows, seq_len)
        # Clamp the sampling range to seq_len: the original sampled in
        # [0, sample_mse_max_row) and would index out of range whenever
        # seq_len < sample_mse_max_row.
        sampled_rows = torch.randint(low=0, high=min(self.sample_mse_max_row, seq_len), size=(num_sampled_rows,))
        sampled_q = query[:, :, sampled_rows, :]
        sampled_qk_scores = torch.matmul(sampled_q, key.transpose(-2, -1)) / (dim**0.5)

        sampled_attn_weights = F.softmax(sampled_qk_scores, dim=-1)
        # Full (unmasked) attention output on the sampled rows: the reference.
        sampled_golden_hidden_states = torch.matmul(sampled_attn_weights, value)  # (cfg, num_heads, num_sampled_rows, dim)

        sampled_mses = torch.zeros(len(self.attention_masks), cfg, num_heads, device=query.device, dtype=query.dtype)

        # Only have Tri-diagonal and Striped
        for mask_idx, attn_mask in enumerate(self.attention_masks):
            sampled_attention_mask = attn_mask[sampled_rows, :]
            sampled_attention_scores = sampled_qk_scores.masked_fill(sampled_attention_mask == 0, float("-inf"))
            sampled_attn_weights = F.softmax(sampled_attention_scores, dim=-1)
            sampled_hidden_states = torch.matmul(sampled_attn_weights, value)
            mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3))
            sampled_mses[mask_idx] = mse

        return sampled_mses

    def reorg(self, q, k, v):
        """Pick the best pattern per head and permute q/k/v accordingly.

        q/k/v: (L, H, D). Returns permuted tensors of the same shape.
        """
        seqlen, head_num, head_dim = q.shape

        # (L, H, D) -> (B, H, L, D)
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)

        sampled_mses = self.sample_mse(q, k, v)
        # Per-head index of the lowest-MSE candidate pattern.
        self.best_mask_idx = torch.argmin(sampled_mses, dim=0)
        self.head_num = head_num
        self.head_dim = head_dim

        q_out, k_out, v_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)

        wan_sparse_head_placement(q, k, v, q_out, k_out, v_out, self.best_mask_idx, self.context_length, self.attnmap_frame_num, seqlen // self.attnmap_frame_num)

        q_out = q_out.transpose(1, 2).squeeze(0)
        k_out = k_out.transpose(1, 2).squeeze(0)
        v_out = v_out.transpose(1, 2).squeeze(0)
        return q_out, k_out, v_out

    def restore(self, out):
        """Undo the head placement done in reorg() on the attention output."""
        # out: (L, H*D)
        out = out.reshape(-1, self.head_num, self.head_dim)
        seqlen = out.shape[0]
        # (L, H, D) -> (B, H, L, D)
        out = out.unsqueeze(0).transpose(1, 2)

        restore_out = torch.zeros_like(out)
        wan_hidden_states_placement(out, restore_out, self.best_mask_idx, self.context_length, self.attnmap_frame_num, seqlen // self.attnmap_frame_num)

        restore_out = restore_out.transpose(1, 2).reshape(seqlen, -1)
        return restore_out

    def __call__(self, q, k):
        self.prepare_mask(q)
        # return: [B, H, Q_block_num, K_block_num]
        return self.mask