# sparse_mla_topk_reducesum.py
# ruff: noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
import tilelang
from tilelang import language as T
from einops import repeat, rearrange, einsum
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio

# Dtype names as strings. NOTE(review): the kernel builder below spells these
# strings out literally instead of referencing these constants.
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"

# Compiler pass options handed to `tilelang.jit` for the kernel below. Judging
# by the key names they turn off TMA lowering and warp specialization — confirm
# the exact effect against the TileLang documentation.
pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}


@tilelang.jit(pass_configs=pass_configs)
def tl_sparse_mla_topk_reducesum_impl(
    heads,
    dim,
    tail_dim,
    topk,
    kv_group=1,
    sm_scale=None,
    block_I=32,
    num_stages=2,
    threads=128,
):
    """Build the TileLang kernel that sums per-head attention weights over heads.

    For every query token and every one of its `topk` selected KV rows the
    kernel writes ``sum_h exp(q_h . k * sm_scale - lse_h)`` into `ReduceSum`,
    where the sum runs over the heads handled by one thread block.

    Args:
        heads: Total number of query heads (second axis of Q).
        dim: Size of the leading channel slice of Q/KV; must be a power of two.
        tail_dim: Size of the trailing channel slice; must be a power of two.
        topk: Number of selected KV indices per (token, group); must be a
            multiple of `block_I`.
        kv_group: Number of KV groups; each group serves `heads // kv_group`
            query heads.
        sm_scale: Scale applied to the logits; defaults to
            ``(1 / (dim + tail_dim)) ** 0.5``.
        block_I: Number of selected indices processed per loop iteration.
        num_stages: Forwarded to `T.Pipelined`.
        threads: Forwarded to `T.Kernel`.

    Returns:
        The `T.prim_func` defined below, taking
        (Q, KV, Indices, Lse, Offsets, TokenIndices, ReduceSum) and writing
        into ReduceSum of shape [seq_len, kv_group, REPLICATE_H, topk].
        What `tilelang.jit` does with it (compilation, caching) is not visible
        in this file.
    """
    assert dim == tilelang.math.next_power_of_2(
        dim), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
    assert (topk %
            block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded"
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim))**0.5

    # Symbolic extents: these sizes are not fixed when the kernel is built.
    batch_plus_one = T.symbolic("batch_plus_one")
    seq_len = T.symbolic("seq_len")
    seq_len_kv = T.symbolic("seq_len_kv")

    head_kv = heads // kv_group
    indices_dtype = "int32"
    dtype = "bfloat16"
    accum_dtype = "float"

    G = kv_group
    H = head_kv
    # Heads per group rounded up to a power of two, with a floor of 16.
    # NOTE(review): when padded_H != H the slice H0:H1 used below reaches past
    # the real heads; how T.copy treats those rows is not visible here —
    # confirm that padded rows cannot contribute to the reduce_sum.
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert (
            kv_group == 1
        ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    D = dim
    D_tail = tail_dim

    # More than 64 heads per group are split into REPLICATE_H chunks of 64
    # heads; each chunk gets its own block along the first grid axis.
    if head_kv > 64:
        assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = head_kv // 64
    else:
        REPLICATE_H = 1

    H_per_block = padded_H if REPLICATE_H == 1 else 64

    q_shape = [seq_len, heads, dim + tail_dim]
    kv_shape = [seq_len_kv, kv_group, dim + tail_dim]
    indices_shape = [seq_len, kv_group, topk]
    lse_shape = [seq_len, heads]
    reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk]
    offsets_shape = [batch_plus_one]
    token_indices_shape = [seq_len, 2]

    @T.prim_func
    def tl_sparse_mla_topk_reducesum_kernel(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            KV: T.Tensor(kv_shape, dtype),  # type: ignore
            Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
            Lse: T.Tensor(lse_shape, accum_dtype),  # type: ignore
            Offsets: T.Tensor(offsets_shape, indices_dtype),  # type: ignore
            TokenIndices: T.Tensor(token_indices_shape, indices_dtype),  # type: ignore
            ReduceSum: T.Tensor(reducesum_shape, accum_dtype),  # type: ignore
    ):
        # Grid: (token, head chunk) pairs along bx, KV group along by.
        with T.Kernel(
                seq_len * REPLICATE_H, kv_group, threads=threads) as (
                    bx,
                    by,
                ):
            # Tiles for the query heads of this block and the gathered KV rows,
            # each split into the leading D channels and the D_tail channels.
            Q_shared = T.alloc_shared([H_per_block, D], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared = T.alloc_shared([BI, D], dtype)
            K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
            mask = T.alloc_fragment([BI], "bool")

            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            reducesum = T.alloc_fragment([BI], accum_dtype)
            lse = T.alloc_fragment([H_per_block], accum_dtype)

            T.fill(lse, 0)

            # Flat token index handled by this block (bx also encodes the head chunk).
            b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
            # TokenIndices row = (b_i, s_i); presumably (sequence id, position
            # inside that sequence) as produced by prepare_token_indices — confirm.
            b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
            # bos is the first flat row of sequence b_i; eos is read but unused below.
            bos, eos = Offsets[b_i], Offsets[b_i + 1]
            r_i = bx % REPLICATE_H
            g_i = by
            q_i = s_i
            # Selected indices greater than the query position are masked out.
            max_kv_i = q_i

            # Head range [H0, H1) handled by this block.
            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block

            T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
            T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
            T.copy(Lse[bos + s_i, H0:H1], lse)

            # Walk the topk indices in NI chunks of BI entries.
            for i_i in T.Pipelined(NI, num_stages=num_stages):

                # An entry is valid when it is not -1 and not beyond the query position.
                for bi_i in T.Parallel(BI):
                    mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (
                        Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)

                # Gather the selected KV rows; indices are offset by bos, i.e.
                # they are relative to the start of the sequence. Rows are
                # loaded regardless of the mask.
                for bi_i, d_i in T.Parallel(BI, D):
                    KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i,
                                              d_i]
                for bi_i, d_i in T.Parallel(BI, D_tail):
                    K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i],
                                                  g_i, D + d_i]

                # Pre-fill the logits with 0 (valid) or -inf (masked). The code
                # relies on T.gemm adding into acc_s rather than overwriting it,
                # so masked columns stay -inf and become 0 after the exp below.
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
                T.gemm(
                    Q_shared,
                    KV_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                T.gemm(
                    Q_tail_shared,
                    K_tail_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                # Scale the logits and subtract the per-head log-sum-exp before exp.
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i])
                # Sum over the head axis and store this chunk's BI results in slot r_i.
                T.reduce_sum(acc_s, reducesum, dim=0)
                T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI:i_i * BI + BI])

    return tl_sparse_mla_topk_reducesum_kernel


def sparse_mla_topk_reducesum_interface(
    q: torch.Tensor,
    kv: torch.Tensor,
    topk_indices: torch.Tensor,
    lse: torch.Tensor,
    offsets: torch.Tensor,
    dim_v: int,
):
    """Run the top-k reduce-sum kernel and normalise its output per token.

    Args:
        q: Query tensor; unpacked as (seq_len, heads, dim_v + tail_dim).
        kv: Key/value tensor whose second-to-last axis must have size 1.
        topk_indices: Selected KV indices; the last axis holds the top-k entries.
        lse: Log-sum-exp values forwarded to the kernel.
        offsets: Sequence boundaries, forwarded to `prepare_token_indices` and
            to the kernel.
        dim_v: Number of leading channels of `q`; the rest form the tail.

    Returns:
        Tensor of shape [seq_len, 1, topk]: the head-summed weights divided by
        their total over the top-k axis.
    """
    assert kv.shape[-2] == 1
    seq_len, heads, dim_plus_tail_dim = q.shape
    topk = topk_indices.shape[-1]

    # One partial result per chunk of 64 heads (at least one chunk).
    REPLICATE_H = max(heads // 64, 1)
    token_indices = prepare_token_indices(offsets)

    partial_sums = torch.zeros(
        [seq_len, 1, REPLICATE_H, topk],
        dtype=torch.float32,
        device=q.device,
    )
    kernel = tl_sparse_mla_topk_reducesum_impl(
        heads=heads,
        dim=dim_v,
        tail_dim=dim_plus_tail_dim - dim_v,
        topk=topk,
    )
    kernel(q, kv, topk_indices, lse, offsets, token_indices, partial_sums)

    # [seq_len, 1, REPLICATE_H, topk] -> [seq_len, 1, topk]
    head_sum = partial_sums.sum(dim=-2)
    return head_sum / head_sum.sum(dim=-1, keepdim=True)


def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor,
                         offsets: torch.Tensor):
    """Dense PyTorch reference for the per-head LSE and the top-k attention scores.

    Sequences are packed along the first axis and delimited by `offsets`; each
    sequence is processed independently with a causal mask.

    Args:
        Q: Packed queries, [total_tokens, heads, dim].
        K: Packed keys, [total_tokens, dim].
        TopkIndices: [total_tokens, topk] key positions relative to the start
            of each token's own sequence. Every entry must be a valid position
            inside that sequence (a -1 sentinel is not supported here).
        offsets: 1-D tensor of sequence boundaries; sequence ``i`` covers rows
            ``offsets[i]:offsets[i + 1]``.

    Returns:
        Tuple ``(lse, topk_score)``:
            lse: [total_tokens, heads] float32 log-sum-exp of the scaled,
                causally masked logits.
            topk_score: [total_tokens, topk] softmax weights summed over
                heads, gathered at `TopkIndices` and normalised to sum to 1.
    """
    sm_scale = Q.shape[-1]**-0.5
    # Read the boundaries once instead of indexing the tensor in every iteration.
    bounds = offsets.tolist()
    all_lse = []
    all_topk_score = []
    for bos, eos in zip(bounds[:-1], bounds[1:]):
        q = Q[bos:eos]
        k = K[bos:eos]
        topk_indices = TopkIndices[bos:eos]
        seq_len = q.shape[0]
        # Build the causal mask on the inputs' device rather than forcing CUDA,
        # so the reference also runs on CPU or on a non-default GPU.
        positions = torch.arange(seq_len, device=q.device)
        mask = (positions[:, None] >= positions[None, :]).unsqueeze(-2)
        logits = torch.einsum('shd,td->sht', q, k) * sm_scale
        logits = logits.masked_fill(~mask, float('-inf'))
        score = F.softmax(logits, dim=-1, dtype=torch.float32)
        # Sum over heads, then keep only the selected key positions.
        score_sum = score.sum(dim=-2)
        topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64))
        topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True)
        # Max-shifted log-sum-exp in float32; the diagonal is never masked, so
        # every row has a finite maximum.
        max_logits = logits.amax(dim=-1).to(torch.float32)
        lse = torch.log((logits - max_logits.unsqueeze(-1)).exp().sum(dim=-1)) + max_logits
        all_lse.append(lse)
        all_topk_score.append(topk_score)
    lse = torch.cat(all_lse, dim=0)
    topk_score = torch.cat(all_topk_score, dim=0)
    return lse, topk_score


def test_kernel(
    B=1,
    S=2048,
    H=16,
    D=512,
    tail_D=64,
    topk=128,
):
    """Compare the TileLang kernel against the PyTorch reference and print the error.

    Builds random bfloat16 queries/keys for two packed sequences (boundaries
    0, 1023, S), lets every token select key positions 0..topk-1, and prints
    the absolute error and error ratio between the two results. `B` is
    accepted but not used.
    """
    torch.manual_seed(42)

    full_dim = D + tail_D
    q = torch.randn((S, H, full_dim)).cuda().bfloat16()
    kv = torch.randn((S, full_dim)).cuda().bfloat16()
    offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()

    # Every token selects the same key positions 0..topk-1.
    first_positions = torch.arange(topk, dtype=torch.int32).cuda()
    topk_indices = repeat(first_positions, 'k -> s k', s=S).contiguous()

    lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets)

    # The kernel interface expects an explicit KV-group axis of size 1.
    kv = kv.unsqueeze(-2)
    topk_indices = topk_indices.unsqueeze(-2)
    kernel_out = sparse_mla_topk_reducesum_interface(
        q, kv, topk_indices, lse, offsets, dim_v=D)
    attn_score = kernel_out.squeeze(-2)

    print(
        f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}"
    )


# Run the kernel-vs-reference comparison with its default sizes when executed as a script.
if __name__ == '__main__':
    test_kernel()