import torch
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
from tilelang.profiler import do_bench
import math


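# Paged MLA (multi-head latent attention) decode example in TileLang.
# mla_decode_tilelang builds a FlashAttention-style decode kernel that reads a
# paged KV cache through a block table; when num_split > 1 the KV range is
# split across programs and the partial results are merged by a combine kernel
# using per-split log-sum-exp values.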
@tilelang.jit(
    out_idx=[8],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None):
    if softmax_scale is None:
        softmax_scale = (dv + dpe) ** -0.5
    scale = float(softmax_scale * 1.44269504)  # softmax_scale * log2(e), so the kernel can use exp2/log2
    dtype = "float16"
    accum_dtype = "float"
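    # Query heads are processed in groups of block_H that all attend to the same
    # KV head (kv_group_num = h_q // h_kv); VALID_BLOCK_H caps the group at the
    # actual number of query heads per KV head.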
    kv_group_num = h_q // h_kv
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert h_kv == 1, "h_kv must be 1"
    assert block_size >= block_N and block_size % block_N == 0, "block_size must be at least block_N and a multiple of block_N"

    @T.macro
    def flash_mla_kernel(
        Q: T.Tensor([batch, h_q, dv], dtype),
        Q_pe: T.Tensor([batch, h_q, dpe], dtype),
        KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
        K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
        CACHE_SEQLENS: T.Tensor([batch], "int32"),
        Output: T.Tensor([batch, h_q, dv], dtype),
    ):
        with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by):
            Q_shared = T.alloc_shared([block_H, dv], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
            KV_shared = T.alloc_shared([block_N, dv], dtype)
            K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
            O_shared = T.alloc_shared([block_H, dv], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = by // (kv_group_num // block_H)
            T.use_swizzle(10)
            T.annotate_layout(
                {
                    O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                    S_shared: tilelang.layout.make_swizzled_layout(S_shared),
                }
            )

            T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N)
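            # FlashAttention-style online softmax over KV blocks, iterating in
            # reverse order so that only the first iteration (the tail block)
            # can contain out-of-range positions.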
            for kr in T.Pipelined(loop_range, num_stages=2):
                k = loop_range - 1 - kr
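                # Resolve the logical KV position to a physical row through the
                # block table (paged KV cache).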
                kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size
                T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
                T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared)
                T.clear(acc_s)
                T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
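                # Only the tail block (kr == 0 in reverse order) can extend past
                # the true sequence length, so masking is applied just there.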
                if kr == 0:
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
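                # Online softmax update in base 2: rescale the running output and
                # running sum by exp2((m_prev - m_new) * scale), then accumulate
                # the current block's probabilities.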
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, S_shared)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dv):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
            for i, j in T.Parallel(block_H, dv):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :])

    @T.macro
    def flash_mla_split_kv_kernel(
        Q: T.Tensor([batch, h_q, dv], dtype),
        Q_pe: T.Tensor([batch, h_q, dpe], dtype),
        KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
        K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
        CACHE_SEQLENS: T.Tensor([batch], "int32"),
        glse: T.Tensor([batch, h_q, num_split], dtype),
        Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
    ):
        with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dv], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
            KV_shared = T.alloc_shared([block_N, dv], dtype)
            K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
            O_shared = T.alloc_shared([block_H, dv], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = by // (kv_group_num // block_H)
            T.use_swizzle(10)
            T.annotate_layout(
                {
                    O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                    S_shared: tilelang.layout.make_swizzled_layout(S_shared),
                }
            )

            T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N)
            blocks_per_split = T.floordiv(total_blocks, num_split)
            remaining_blocks = T.floormod(total_blocks, num_split)
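            # Distribute the KV blocks across num_split programs: each split gets
            # blocks_per_split blocks, the first remaining_blocks splits get one
            # extra, and start is the first KV position this split processes.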
            loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)
            start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N

            for k in T.Pipelined(loop_range, num_stages=2):
                # Use the split's global position (start + k * block_N) for both the block-table index and the in-block offset.
                kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (start + k * block_N) % block_size
                T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
                T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared)
                T.clear(acc_s)
                T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
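                # Mask positions beyond the true sequence length (the check runs
                # every iteration; only the final block can actually be partial).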
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, S_shared)
                T.copy(S_shared, acc_s_cast)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dv):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
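            # Normalize this split's partial output and record its log-sum-exp
            # (base 2) so the combine kernel can merge the splits.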
            for i, j in T.Parallel(block_H, dv):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :])

    @T.macro
    def combine(
        glse: T.Tensor([batch, h_q, num_split], dtype),
        Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
        Output: T.Tensor([batch, h_q, dv], dtype),
    ):
        with T.Kernel(h_q, batch, threads=128) as (by, bz):
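            # Merge the per-split partial outputs: compute a numerically stable
            # global log-sum-exp across splits, then weight each split's output
            # by exp2(lse_split - lse_global).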
            po_local = T.alloc_fragment([dv], dtype)
            o_accum_local = T.alloc_fragment([dv], accum_dtype)
            lse_local_split = T.alloc_local([1], accum_dtype)
            lse_logsum_local = T.alloc_local([1], accum_dtype)
            lse_max_local = T.alloc_local([1], accum_dtype)
            scale_local = T.alloc_local([1], accum_dtype)

            T.annotate_layout(
                {
                    lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
                }
            )

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            lse_max_local[0] = -T.infinity(accum_dtype)
            for k in T.serial(num_split):
                lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
            for k in T.Pipelined(num_split, num_stages=1):
                lse_local_split[0] = glse[bz, by, k]
                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
            lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
            for k in T.serial(num_split):
                for i in T.Parallel(dv):
                    po_local[i] = Output_partial[bz, by, k, i]
                lse_local_split[0] = glse[bz, by, k]
                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                for i in T.Parallel(dv):
                    o_accum_local[i] += po_local[i] * scale_local[0]
            for i in T.Parallel(dv):
                Output[bz, by, i] = o_accum_local[i]

    @T.prim_func
    def main_split(
        Q: T.Tensor([batch, h_q, dv], dtype),
        Q_pe: T.Tensor([batch, h_q, dpe], dtype),
        KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
        K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
        block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
        cache_seqlens: T.Tensor([batch], "int32"),
        glse: T.Tensor([batch, h_q, num_split], dtype),
        Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
        Output: T.Tensor([batch, h_q, dv], dtype),
    ):
        flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial)
        combine(glse, Output_partial, Output)

    @T.prim_func
    def main_no_split(
        Q: T.Tensor([batch, h_q, dv], dtype),
        Q_pe: T.Tensor([batch, h_q, dpe], dtype),
        KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
        K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
        block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
        cache_seqlens: T.Tensor([batch], "int32"),
        glse: T.Tensor([batch, h_q, num_split], dtype),
        Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
        Output: T.Tensor([batch, h_q, dv], dtype),
    ):
        flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output)

    if num_split > 1:
        return main_split
    else:
        return main_no_split


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
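    # Float32 reference attention; KV heads are repeated to match the query
    # heads (MQA/GQA) and the log-sum-exp is returned alongside the output.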
    query = query.float()
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)
        attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse


@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    # q: [b, s_q, h_q, d]
    # block_table: [b, max_seqlen_pad // block_size]
    # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
    # cache_seqlens: [b]
    blocked_v = blocked_k[..., :dv]

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            O, LSE = scaled_dot_product_attention(
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                h_q,
                h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
            lse[i] = LSE
        return out.to(dtype), lse.to(dtype)

    out_torch, _ = ref_mla()
    return out_torch


def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    assert d > dv, "MLA with RoPE requires d > dv (dpe = d - dv must be positive)"
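    # Split the query and cached keys into the no-RoPE part (dv) and the RoPE
    # part (dpe = d - dv) expected by the kernel.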
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    dpe = d - dv
    num_kv_splits = 1
    BLOCK_N = 64
    BLOCK_H = min(64, h_q // h_kv)
    softmax_scale = d**-0.5

    out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
    glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
    kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale)
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

    def flash_mla_tilelang():
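        # out_idx=[8] marks the Output tensor as a kernel-allocated output, so
        # the call below passes only the first eight arguments and receives
        # the result as the return value.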
        out = profiler.func(
            q_nope.view(-1, h_q, dv),
            q_pe.view(-1, h_q, dpe),
            blocked_k_nope.view(-1, h_kv, dv),
            blocked_k_pe.view(-1, h_kv, dpe),
            block_table,
            cache_seqlens,
            glse,
            out_partial,
        )
        return out.view([b, s_q, h_q, dv])

    out_flash = flash_mla_tilelang()
    t = do_bench(flash_mla_tilelang)
    out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
    torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01)
    print("All close")
    return out_flash, t


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=128, help="batch size")
    parser.add_argument("--h_q", type=int, default=128, help="q heads number")
    parser.add_argument("--h_kv", type=int, default=1, help="kv heads number")
    parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length")
    parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe")
    parser.add_argument("--dv", type=int, default=512, help="value head dim")
    args = parser.parse_args()
    b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv

    device = "cuda"
    dtype = torch.float16

    s_q = 1  # for decode, s_q = 1
    block_size = 64
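    # Per-batch cache lengths differ slightly so partial tail blocks get exercised.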
    cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device)
    dpe = d - dv
    causal = True

    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = math.ceil(max_seqlen / 256) * 256

    total_flops = s_q * total_seqlens * h_q * d * 2  # 2 FLOPs per multiply-add in the QK^T GEMM over head dim d

    q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
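    # Identity block table: batch i owns the contiguous physical blocks
    # [i * max_seqlen_pad // block_size, (i + 1) * max_seqlen_pad // block_size).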
    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)
    out_flash, latency = run_tilelang_mla(
        q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
    )

    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))