# sparse_operator.py
import torch

from lightx2v.utils.registry_factory import SPARSE_OPERATOR_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

from .kernels.sla_kernel import _attention

# Optional third-party backends; each falls back to None when it is not installed.
try:
    from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
    magi_ffa_func = None

try:
    from flex_block_attn import flex_block_attn_func
except ImportError:
    flex_block_attn_func = None

try:
    import flashinfer
except ImportError:
    flashinfer = None


@SPARSE_OPERATOR_REGISTER("sla_triton_operator")
class SlaTritonOperator:
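    """Block-sparse attention using the local SLA Triton kernel (.kernels.sla_kernel).

    Expects q/k/v shaped (seqlen, head_num, head_dim) plus a 0/1 block mask shaped
    (B, H, Q_block_num, K_block_num); returns (seqlen, head_num * head_dim).
    """
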
    def __init__(self):
        self.q_block_size = 128
        self.k_block_size = 128

    def __call__(
        self,
        q,
        k,
        v,
        mask,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        **kwargs,
    ):
        # (L, H, D) -> (B, H, L, D)
        q = q.unsqueeze(0).transpose(1, 2).contiguous()
        k = k.unsqueeze(0).transpose(1, 2).contiguous()
        v = v.unsqueeze(0).transpose(1, 2).contiguous()

        # mask: (B, H, Q_block_num, K_block_num). Build a rectangular lookup table holding,
        # per query block, the indices of up to `topk` active K blocks (topk = largest row sum).
        mask = mask.int()
        topk = int(mask.sum(dim=-1).max().item())
        lut = torch.topk(mask, topk, dim=-1, sorted=False).indices

        out = _attention.apply(q, k, v, mask, lut, topk, self.q_block_size, self.k_block_size)
        out = out.transpose(1, 2).reshape(max_seqlen_q, -1)
        return out


@SPARSE_OPERATOR_REGISTER("magi_operator")
class MagiOperator:
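    """Block-sparse attention via magi_attention's flex_flash_attn_func.

    The per-head block mask is turned into explicit (q_range, k_range) pairs and the
    heads are flattened into the sequence dimension, so one packed call covers all heads.
    """
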
    def __init__(self):
        self.q_block_size = 128
        self.k_block_size = 128

    def generate_qk_ranges(self, mask, q_block_size, k_block_size, seqlen):
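        """Expand the per-head block mask into flat (start, end) token ranges.

        Each head occupies its own seqlen-sized segment of the flattened sequence, so the
        returned int32 ranges index directly into the packed (head_num * seqlen) q/k/v.
        """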
        # mask: [H, Q_block_num, K_block_num]
        h_indices, i_indices, j_indices = torch.nonzero(mask, as_tuple=True)

        base_offset = h_indices * seqlen

        q_start = base_offset + i_indices * q_block_size
        q_end = base_offset + torch.clamp((i_indices + 1) * q_block_size, max=seqlen)

        k_start = base_offset + j_indices * k_block_size
        k_end = base_offset + torch.clamp((j_indices + 1) * k_block_size, max=seqlen)

        q_ranges = torch.stack([q_start, q_end], dim=1).to(dtype=torch.int32)
        k_ranges = torch.stack([k_start, k_end], dim=1).to(dtype=torch.int32)

        return q_ranges, k_ranges

    def __call__(
        self,
        q,
        k,
        v,
        mask,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        **kwargs,
    ):
        seqlen, head_num, head_dim = q.shape
        # (B, H, Q_block_num, K_block_num) -> (H, Q_block_num, K_block_num)
        mask = mask.squeeze(0)
        q_ranges, k_ranges = self.generate_qk_ranges(mask, self.q_block_size, self.k_block_size, seqlen)
        # attn_type 0 selects plain full (non-causal) attention for every (q_range, k_range) pair.
        attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)

        # Flatten heads into the sequence dimension, (seqlen, H, D) -> (H * seqlen, 1, D),
        # to match the per-head offsets baked into q_ranges / k_ranges.
        q = q.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
        k = k.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
        v = v.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)

        out = magi_ffa_func(
            q,
            k,
            v,
            q_ranges=q_ranges,
            k_ranges=k_ranges,
            attn_type_map=attn_type_map,
            auto_range_merge=True,
        )[0]

        out = out.reshape(head_num, seqlen, head_dim).permute(1, 0, 2)

        return out.reshape(out.shape[0], -1)


@SPARSE_OPERATOR_REGISTER("flex_block_operator")
class FlexBlockOperator:
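    """Block-sparse attention via the external flex_block_attn kernel.

    Inputs are zero-padded up to a multiple of q_block_size before the kernel call and
    cropped back to the original length afterwards.
    """
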
    def __init__(self):
        self.q_block_size = 128
        self.k_block_size = 128

    def __call__(
        self,
        q,
        k,
        v,
        mask,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        **kwargs,
    ):
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)

        # Zero-pad the sequence dimension up to a multiple of q_block_size; the same pad_len
        # is applied to k and v, assuming self-attention (seqlen_q == seqlen_kv).
        pad_len = (self.q_block_size - q.shape[2] % self.q_block_size) % self.q_block_size
        if pad_len > 0:
            q = torch.nn.functional.pad(q, (0, 0, 0, pad_len))
            k = torch.nn.functional.pad(k, (0, 0, 0, pad_len))
            v = torch.nn.functional.pad(v, (0, 0, 0, pad_len))

        # (B, H, Q_block_num, K_block_num)
        mask = mask.bool()
        out = flex_block_attn_func(q, k, v, self.q_block_size, self.k_block_size, mask)

        if pad_len > 0:
            out = out[:, :, :-pad_len, :]

        out = out.transpose(1, 2)

        return out.reshape(max_seqlen_q, -1)


@SPARSE_OPERATOR_REGISTER("flashinfer_operator")
class FlashinferOperator:
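    """Block-sparse attention via flashinfer's VariableBlockSparseAttentionWrapper.

    The wrapper (with a 1 GiB float workspace buffer) is shared across instances, and
    plan() is re-run only when the block mask changes between calls.
    """
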
    sparse_wrapper = None
    mask = None

    def __init__(self):
        self.q_block_size = 128
        self.k_block_size = 128
        if FlashinferOperator.sparse_wrapper is None:
            float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device=AI_DEVICE)
            FlashinferOperator.sparse_wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="fa2")

    def __call__(
        self,
        q,
        k,
        v,
        mask,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        **kwargs,
    ):
        seqlen, head_num, head_dim = q.shape
        # (B, H, Q_block_num, K_block_num) -> (H, Q_block_num, K_block_num)
        mask = mask.squeeze(0)
        # Re-plan only when the block mask changes between calls; otherwise reuse the cached plan.
        if FlashinferOperator.mask is None or not torch.equal(mask, FlashinferOperator.mask):
            _, q_block_num, k_block_num = mask.shape
            # Every block is full-sized except the last, which covers the remainder of the sequence.
            block_row_sz = torch.ones(q_block_num, dtype=torch.int32, device=q.device) * self.q_block_size
            block_row_sz[-1] = seqlen - self.q_block_size * (q_block_num - 1)
            block_row_sz = block_row_sz.unsqueeze(0).repeat(head_num, 1)
            block_col_sz = torch.ones(k_block_num, dtype=torch.int32, device=k.device) * self.k_block_size
            block_col_sz[-1] = seqlen - self.k_block_size * (k_block_num - 1)
            block_col_sz = block_col_sz.unsqueeze(0).repeat(head_num, 1)
            FlashinferOperator.sparse_wrapper.plan(
                block_mask_map=mask,
                block_row_sz=block_row_sz,
                block_col_sz=block_col_sz,
                num_qo_heads=head_num,
                num_kv_heads=head_num,
                head_dim=head_dim,
                q_data_type=q.dtype,
            )
            FlashinferOperator.mask = mask

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        out = FlashinferOperator.sparse_wrapper.run(q, k, v)
        out = out.transpose(0, 1)

        return out.reshape(max_seqlen_q, -1)
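

# Minimal usage sketch (illustrative only). The subscript-style registry lookup and the
# exact tensor shapes are assumptions read off the signatures above, not a documented
# lightx2v API:
#
#     op = SPARSE_OPERATOR_REGISTER["sla_triton_operator"]()  # hypothetical lookup
#     # q, k, v: (seqlen, head_num, head_dim); mask: (1, head_num, Q_block_num, K_block_num)
#     out = op(q, k, v, mask, max_seqlen_q=q.shape[0])  # -> (seqlen, head_num * head_dim)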