sparse_mla_topk_reducesum.py 8.45 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# ruff: noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
import tilelang
from tilelang import language as T
from einops import repeat, rearrange, einsum
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio

BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"

pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}


@tilelang.jit(pass_configs=pass_configs)
def tl_sparse_mla_topk_reducesum_impl(
    heads,
    dim,
    tail_dim,
    topk,
    kv_group=1,
    sm_scale=None,
    block_I=32,
    num_stages=2,
    threads=128,
):
33
34
35
    assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
    assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
36
    if sm_scale is None:
37
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5
38
39
40
41
42
43
44
45
46
47
48
49
50
51

    batch_plus_one = T.symbolic("batch_plus_one")
    seq_len = T.symbolic("seq_len")
    seq_len_kv = T.symbolic("seq_len_kv")

    head_kv = heads // kv_group
    indices_dtype = "int32"
    dtype = "bfloat16"
    accum_dtype = "float"

    G = kv_group
    H = head_kv
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
52
53
54
        assert kv_group == 1, (
            "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
        )
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    D = dim
    D_tail = tail_dim

    if head_kv > 64:
        assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = head_kv // 64
    else:
        REPLICATE_H = 1

    H_per_block = padded_H if REPLICATE_H == 1 else 64

    q_shape = [seq_len, heads, dim + tail_dim]
    kv_shape = [seq_len_kv, kv_group, dim + tail_dim]
    indices_shape = [seq_len, kv_group, topk]
    lse_shape = [seq_len, heads]
    reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk]
    offsets_shape = [batch_plus_one]
    token_indices_shape = [seq_len, 2]

    @T.prim_func
    def tl_sparse_mla_topk_reducesum_kernel(
78
79
80
81
82
83
84
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Lse: T.Tensor(lse_shape, accum_dtype),  # type: ignore
        Offsets: T.Tensor(offsets_shape, indices_dtype),  # type: ignore
        TokenIndices: T.Tensor(token_indices_shape, indices_dtype),  # type: ignore
        ReduceSum: T.Tensor(reducesum_shape, accum_dtype),  # type: ignore
85
    ):
86
87
88
89
        with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as (
            bx,
            by,
        ):
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
            Q_shared = T.alloc_shared([H_per_block, D], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared = T.alloc_shared([BI, D], dtype)
            K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
            mask = T.alloc_fragment([BI], "bool")

            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            reducesum = T.alloc_fragment([BI], accum_dtype)
            lse = T.alloc_fragment([H_per_block], accum_dtype)

            T.fill(lse, 0)

            b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
            b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
            bos, eos = Offsets[b_i], Offsets[b_i + 1]
            r_i = bx % REPLICATE_H
            g_i = by
            q_i = s_i
            max_kv_i = q_i

            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block

            T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
            T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
            T.copy(Lse[bos + s_i, H0:H1], lse)

            for i_i in T.Pipelined(NI, num_stages=num_stages):
                for bi_i in T.Parallel(BI):
119
                    mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
120
121

                for bi_i, d_i in T.Parallel(BI, D):
122
                    KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
123
                for bi_i, d_i in T.Parallel(BI, D_tail):
124
                    K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144

                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
                T.gemm(
                    Q_shared,
                    KV_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                T.gemm(
                    Q_tail_shared,
                    K_tail_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i])
                T.reduce_sum(acc_s, reducesum, dim=0)
145
                T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI])
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

    return tl_sparse_mla_topk_reducesum_kernel


def sparse_mla_topk_reducesum_interface(
    q: torch.Tensor,
    kv: torch.Tensor,
    topk_indices: torch.Tensor,
    lse: torch.Tensor,
    offsets: torch.Tensor,
    dim_v: int,
):
    assert kv.shape[-2] == 1
    seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1]
    REPLICATE_H = max(heads // 64, 1)
    tail_dim = dim_plus_tail_dim - dim_v
    token_indices = prepare_token_indices(offsets)

    reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device)
    kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, dim=dim_v, tail_dim=tail_dim, topk=topk)
    kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum)
    reducesum = reducesum.sum(dim=-2)  # [batch, seq_len, 1, RH, topk] -> [batch, seq_len, 1, topk]
    attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True)

    return attn_score


173
def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor):
174
175
    # q: [batch, seq_len, heads, dim]
    # k: [batch, seq_len, dim]
176
    sm_scale = Q.shape[-1] ** -0.5
177
178
179
    all_lse = []
    all_topk_score = []
    for i in range(offsets.shape[0] - 1):
180
181
182
        q = Q[offsets[i] : offsets[i + 1]]
        k = K[offsets[i] : offsets[i + 1]]
        topk_indices = TopkIndices[offsets[i] : offsets[i + 1]]
183
        seq_len = q.shape[0]
184
185
186
        mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda()
        logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale
        logits = torch.where(mask, logits, float("-inf"))
187
188
189
190
191
        score = F.softmax(logits, dim=-1, dtype=torch.float32)
        score_sum = score.sum(dim=-2)
        topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64))
        topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True)
        max_logits = logits.amax(dim=-1).to(torch.float32)
192
        lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
        all_lse.append(lse)
        all_topk_score.append(topk_score)
    lse = torch.cat(all_lse, dim=0)
    topk_score = torch.cat(all_topk_score, dim=0)
    return lse, topk_score


def test_kernel(
    B=1,
    S=2048,
    H=16,
    D=512,
    tail_D=64,
    topk=128,
):
    torch.manual_seed(42)

    q = torch.randn((S, H, D + tail_D)).cuda().bfloat16()
    kv = torch.randn((S, D + tail_D)).cuda().bfloat16()
    offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()

214
    topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
215
216
217
218
219
220

    lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets)

    kv = kv.unsqueeze(-2)
    topk_indices = topk_indices.unsqueeze(-2)

221
222
    attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2)
    print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}")
223
224


225
if __name__ == "__main__":
226
    test_kernel()