# ruff: noqa
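#
# Backward pass of top-k sparse MLA attention, implemented as three TileLang kernels:
#   preprocess  : Delta[s, h] = sum_d O[s, h, d] * dO[s, h, d]
#   bwd         : recomputes P from the saved lse, writes dQ directly, and
#                 accumulates dKV in fp32 with atomic adds
#   postprocess : casts the fp32 dKV accumulator back to bf16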
import tilelang
from tilelang import language as T
import torch
from index import prepare_token_indices

from utils import assert_tensors_similar


@tilelang.jit(out_idx=[-1])
def preprocess(
    H,
    D,
    block_ND=32,
    num_stages=5,
    dtype=T.bfloat16,
    accum_dtype=T.float32,
):
    assert dtype == T.bfloat16
    assert accum_dtype == T.float32

    S = T.symbolic("S")

    shape = [S, H, D]

    @T.prim_func
    def preprocess_kernel(
        O: T.Tensor(shape, dtype),
        dO: T.Tensor(shape, dtype),
        Delta: T.Tensor([S, H], accum_dtype),
    ):
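        # One program per (head, tile of block_ND tokens): compute the per-row
        # dot product of O and dO needed by the attention backward pass.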
        with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by):
            o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
            do = T.alloc_fragment([block_ND, block_ND], accum_dtype)
            delta = T.alloc_fragment([block_ND], accum_dtype)
            acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
            T.clear(acc)
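            # Accumulate elementwise O * dO over the feature dim D in block_ND-wide chunks.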
            for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
                T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o)
                T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do)
                for i, j in T.Parallel(block_ND, block_ND):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx])

    return preprocess_kernel


@tilelang.jit(out_idx=[-1])
def postprocess(
    D,
    D_tail,
    kv_group=1,
    block_N=64,
    threads=128,
    dtype=T.bfloat16,
    accum_dtype=T.float32,
):
    assert dtype == T.bfloat16
    assert accum_dtype == T.float32
    S_kv = T.symbolic("S_kv")

    dkv_shape = [S_kv, kv_group, D + D_tail]

    @T.prim_func
    def postprocess_kernel(
        dKV: T.Tensor(dkv_shape, accum_dtype),
        dKV_out: T.Tensor(dkv_shape, dtype),
    ):
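        # Plain tile-by-tile cast of the fp32 dKV accumulator down to bf16.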
        with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by):
            T.copy(
                dKV[bx * block_N : (bx + 1) * block_N, by, :],
                dKV_out[bx * block_N : (bx + 1) * block_N, by, :],
            )

    return postprocess_kernel


@tilelang.jit(
    out_idx=[-2],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def bwd(
    H,
    D,
    D_tail,
    topk,
    kv_group=1,
    sm_scale=None,
    is_causal=True,
    block_size=32,
    num_stages=0,
    threads=128,
    indices_dtype=T.int32,
    dtype=T.bfloat16,
    accum_dtype=T.float32,
):
    assert is_causal, "non-causal attention is not supported yet"
    assert topk % block_size == 0, "topk must be a multiple of block_size; otherwise the tail block would read index 0 and load the wrong KV rows"
    assert dtype == T.bfloat16
    assert accum_dtype == T.float32
    assert indices_dtype == T.int32

    if sm_scale is None:
        sm_scale = (D + D_tail) ** (-0.5)

    B_plus_one = T.symbolic("B_plus_one")
    S = T.symbolic("S")

    H_kv = H // kv_group
    q_shape = [S, H, D + D_tail]
    k_shape = [S, kv_group, D + D_tail]
    o_shape = [S, H, D]
    indices_shape = [S, kv_group, topk]
    delta_shape = [S, H]
    lse_shape = [S, H]
    offsets_shape = [B_plus_one]
    token_indices_shape = [S, 2]

    H = H_kv  # from here on, H is the number of query heads per KV group
    padded_H = max(tilelang.math.next_power_of_2(H_kv), 16)
    BS = block_size
    NS = tilelang.cdiv(topk, block_size)

    split_store = 2  # stage dKV in two half-tiles so the fp32 staging view fits in the bf16 shared buffers

    @T.prim_func
    def sparse_mla_bwd_kernel(
        Q: T.Tensor(q_shape, dtype),
        KV: T.Tensor(k_shape, dtype),
        dO: T.Tensor(o_shape, dtype),
        Indices: T.Tensor(indices_shape, indices_dtype),
        Lse: T.Tensor(lse_shape, accum_dtype),
        Delta: T.Tensor(delta_shape, accum_dtype),
        Offsets: T.Tensor(offsets_shape, indices_dtype),
        TokenIndices: T.Tensor(token_indices_shape, indices_dtype),
        dQ: T.Tensor(q_shape, dtype),
        dKV: T.Tensor(k_shape, accum_dtype),
    ):
        with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz):
            Q_shared = T.alloc_shared([padded_H, D], dtype)
            Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
            KV_shared = T.alloc_shared([BS, D], dtype)
            KV_tail_shared = T.alloc_shared([BS, D_tail], dtype)
            dO_shared = T.alloc_shared([padded_H, D], dtype)
            mask = T.alloc_fragment([BS], "bool")

            P_shared_cast = T.alloc_shared([padded_H, BS], dtype)
            dP_shared_cast = T.alloc_shared([padded_H, BS], dtype)
            dQ_shared = T.alloc_shared([padded_H, D], dtype)
            dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)

            acc_p = T.alloc_fragment([padded_H, BS], accum_dtype)
            acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype)
            acc_dq = T.alloc_fragment([padded_H, D], accum_dtype)
            acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype)
            acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
            acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
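            # Reuse the bf16 KV shared buffers as fp32 staging for dKV writes:
            # same byte count, half the rows (hence the split_store staging).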
            acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
            acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)

            # Map flat token id -> (sequence id, position within sequence) for varlen batching
            b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
            bos, eos = Offsets[b_i], Offsets[b_i + 1]

            max_kv_i = s_i  # causal bound: only keys at positions <= s_i are attended

            # Load this token's Q (split into latent part and RoPE tail) and dO into shared memory
            T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared)
            T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared)
            T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared)

            T.clear(acc_dq)
            T.clear(acc_dq_tail)

            # Swizzle the dQ staging buffers to avoid shared-memory bank conflicts
            T.annotate_layout(
                {
                    dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
                    dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
                }
            )

            # Process each block of indices
            for i_i in T.Pipelined(NS, num_stages=num_stages):
                # Check which indices are valid
                for bi_i in T.Parallel(BS):
                    mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)

                # Compute attention scores
                for h_i, bi_i in T.Parallel(padded_H, BS):
                    acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype))

                # Gather the KV latent rows selected by this block of indices
                for bi_i, d_i in T.Parallel(BS, D):
                    KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i]

                T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

                for bi_i, d_i in T.Parallel(BS, D_tail):
                    KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i]
                T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
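
                # Recover the softmax probabilities: P = exp(scores * sm_scale - lse).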
                for h_i, bi_i in T.Parallel(padded_H, BS):
                    acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i])

                T.copy(acc_p, P_shared_cast)

                # dP = dO @ KV^T over the value channels (the first D dims)
                T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)

                # dS = P * (dP - Delta) * sm_scale, stored back into acc_dp
                for h_i, bi_i in T.Parallel(padded_H, BS):
                    acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale

                T.copy(acc_dp, dP_shared_cast)
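                # dQ accumulation: dS @ KV for the latent part, dS @ KV_tail for the RoPE tail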
                T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)

                # dKV: K and V share the latent channels, so dKV = dS^T @ Q + P^T @ dO
                T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
                T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)

                # The RoPE tail only receives the dK term
                T.clear(acc_dkv_tail)
                T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)

                # Stage each half-tile of acc_dkv through shared memory, then atomic-add it into global dKV
                for s in range(split_store):
                    for bi_i, d_i in T.Parallel(BS, D):
                        if bi_i < BS // split_store:
                            acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i]

                    for bi_i, d_i in T.Parallel(BS, D_tail):
                        if bi_i < BS // split_store:
                            acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i]

                    # Vectorized fp32 atomic adds, 4 floats at a time
                    for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
                        T.atomic_addx4(
                            dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4],
                            acc_dkv_shared[bi_i, d_i * 4],
                        )

                    # Same atomic update for the tail (RoPE) channels
                    for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
                        T.atomic_addx4(
                            dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4],
                            acc_dkv_tail_shared[bi_i, d_i * 4],
                        )

            # Store the accumulated dQ
            T.copy(acc_dq, dQ_shared)
            T.copy(acc_dq_tail, dQ_tail_shared)

            T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D])
            T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:])

    return sparse_mla_bwd_kernel


def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_causal=True, return_kernel=False, delta=None):
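    """Dispatch the TileLang kernels to compute dQ and dKV.

    Shapes (varlen layout, B sequences packed along S):
        q:       [S, H, D + D_tail] bf16, D = 512 latent dims plus a RoPE tail
        kv:      [S, kv_group, D + D_tail] bf16
        o, do:   [S, H, D] bf16 forward output and incoming gradient
        indices: [S, kv_group, topk] int32 top-k KV positions per token
        lse:     [S, H] fp32 log-sum-exp saved by the forward pass
        offsets: [B + 1] int32 sequence boundaries
    """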
    assert q.is_contiguous()
    assert kv.is_contiguous()
    assert indices.is_contiguous()
    assert lse.is_contiguous()
    S, H, dim_plus_tail_dim = q.shape
    S_kv, kv_group, _ = kv.shape
    assert kv.shape[-1] == dim_plus_tail_dim
    assert S == S_kv
    # The latent (value) dim is fixed at 512; the remaining channels form the RoPE tail.
    D = 512

    D_tail = dim_plus_tail_dim - D
    topk = indices.shape[-1]
    assert indices.shape == (S, kv_group, topk)
    assert lse.shape == (S, H)

    token_indices = prepare_token_indices(offsets)

    # Get kernels
    preprocess_kernel = preprocess(H, D)
    bwd_kernel = bwd(H, D, D_tail, topk, kv_group, sm_scale, is_causal)
    postprocess_kernel = postprocess(D, D_tail, kv_group)

    if delta is None:
        delta = preprocess_kernel(o, do)
    dkv = torch.zeros_like(kv, dtype=torch.float32)
    dq = bwd_kernel(q, kv, do, indices, lse, delta, offsets, token_indices, dkv)
    dkv = postprocess_kernel(dkv)

    return dq, dkv


def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_causal=True):
    from sparse_mla_fwd import ref_sparse_mla_fwd_interface

    q = q.detach().clone()
    kv = kv.detach().clone()
    q.requires_grad = True
    kv.requires_grad = True
    o = ref_sparse_mla_fwd_interface(q, kv, indices, offsets, sm_scale, is_causal)
    o.backward(do)
    return q.grad, kv.grad


def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True):
    # Prepare data
    q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
    kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
    do = torch.randn((S, H, DV), dtype=dtype, device="cuda")
    offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda")

    # Fill with S (an out-of-range sentinel) so unfilled slots are masked out by the kernel.
    indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda")
    for i in range(offsets.shape[0] - 1):
        seq_len = (offsets[i + 1] - offsets[i]).item()
        assert seq_len >= topk
        for t in range(seq_len):
            for h in range(HKV):
                i_i = torch.randperm(max(1, t))[:topk]
                indices[offsets[i] + t, h, : len(i_i)] = i_i

    # Forward
    from sparse_mla_fwd import sparse_mla_fwd_interface

    tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets)

    tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)
    ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None, offsets)

    if check_correctness:
        assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq")
        assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
        print("assert_tensors_similar passed")

    # Five backward GEMMs per token: recomputed P (DQKV), dP (DV), dQ (DQKV), dKV (DQKV), dV (DV)
    per_token_flop = 2 * sum(
        [
            H * DV * topk,
            H * DQKV * topk,
            H * DQKV * topk,
            H * DQKV * topk,
            H * DV * topk,
        ]
    )
    from tilelang.profiler import do_bench

    def fn():
        return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)

    ms = do_bench(fn, rep=100, warmup=250)
    print(f"Average time: {ms:.3f} ms")
    print(f"bwd io bandwidth = {(B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12:.3f} TB/s")
    print(f"bwd tflops = {per_token_flop * S / (ms * 1e-3) / 1e12:.3f}")


if __name__ == "__main__":
    test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True)