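"""Warp-specialized FlashMLA-style decode attention written in TileLang.

One thread block of 256 threads works on a single query block with two warpgroups:
each warpgroup scores every other KV tile and owns one half of the latent head
dimension, exchanging online-softmax statistics through shared-memory barriers.
`ref_program` checks the kernel against a plain softmax attention computed with
PyTorch and einops.
"""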
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse

tilelang.disable_cache()


@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"
    h_dim = dim // 2
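    # The latent/value head dim is split in half: warpgroup 0 owns the left h_dim
    # columns of KV and of the output, warpgroup 1 owns the right h_dim columns.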
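
    # Warp specialization: the 256 threads per block form two warpgroups.
    #   - Warpgroup 0 (tx < 128) runs the score GEMMs for even KV tiles (2k) and
    #     accumulates the left half of the output, writing Output[..., :h_dim].
    #   - Warpgroup 1 (tx >= 128) scores the odd KV tiles (2k+1) and accumulates the
    #     right half, writing Output[..., h_dim:].
    # The two groups exchange softmax statistics (scores_max, scores_scale_0/1, logsum)
    # and fp16 probabilities (SP0_shared / SP1_shared) through shared memory, split the
    # global->shared KV copies between them, and synchronize with the barriers allocated
    # inside the kernel.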

    @T.macro
    def flash_attn(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
            # smem_sQ
            Q_shared_l = T.alloc_shared([block_H, h_dim], dtype)
            Q_shared_r = T.alloc_shared([block_H, h_dim], dtype)
            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
            Q_pe_local_0 = T.alloc_fragment([block_H, pe_dim], dtype)
            Q_pe_local_1 = T.alloc_fragment([block_H, pe_dim], dtype)

            # smem_sK0
            KV_shared_0_l = T.alloc_shared([block_N, h_dim], dtype)
            KV_shared_0_r = T.alloc_shared([block_N, h_dim], dtype)
            K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)

            # smem_sK1
            KV_shared_1_l = T.alloc_shared([block_N, h_dim], dtype)
            KV_shared_1_r = T.alloc_shared([block_N, h_dim], dtype)
            K_pe_shared_1 = T.alloc_shared([block_N, pe_dim], dtype)

            # smem_sP0
            SP0_shared = T.alloc_shared([block_H, block_N], dtype)

            # smem_sP1 reuse Q_pe_shared
            SP1_shared = Q_pe_shared
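            # Q_pe_shared is consumed into per-warpgroup registers (Q_pe_local_0/1)
            # before the main loop, so its buffer can double as SP1. The aliasing
            # assumes pe_dim == block_N (both 64 in the default config) so the tile
            # shapes match.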

            # smem_sM
            scores_max = T.alloc_shared([block_H], accum_dtype)

            # smem_sScale0
            scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
            # smem_sScale1
            scores_scale_1 = T.alloc_shared([block_H], accum_dtype)

            logsum = T.alloc_shared([block_H], accum_dtype)

            O_shared_l = Q_shared_l
            O_shared_r = Q_shared_r

            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_0_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_s_1 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_1_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o_l = T.alloc_fragment([block_H, h_dim], accum_dtype)
            acc_o_r = T.alloc_fragment([block_H, h_dim], accum_dtype)
            scores_max_0 = T.alloc_fragment([block_H], accum_dtype)
            scores_max_1 = T.alloc_fragment([block_H], accum_dtype)

            scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype)

            scores_sum_0 = T.alloc_fragment([block_H], accum_dtype)
            scores_sum_1 = T.alloc_fragment([block_H], accum_dtype)
            logsum_0 = T.alloc_fragment([block_H], accum_dtype)
            logsum_1 = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = hid // (kv_group_num // block_H)

            T.annotate_layout({
                O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
                O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
            })

            # barriers_Q
            q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)

            # barriers_K0
            kv_shared_0_l_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_0_r_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_0_pe_is_ready = T.alloc_barrier(arrive_count=128)
            # barriers_K1
            kv_shared_1_l_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_1_r_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_1_pe_is_ready = T.alloc_barrier(arrive_count=128)

            # redundant barriers
            score_max_0_ready_barrier = T.alloc_barrier(arrive_count=128)
            scale_1_ready_barrier = T.alloc_barrier(arrive_count=128)
            p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128)
            lse_0_ready_barrier = T.alloc_barrier(arrive_count=128)
            lse_1_ready_barrier = T.alloc_barrier(arrive_count=128)
            s_shared_ready_barrier = T.alloc_barrier(arrive_count=128)
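            # Each barrier expects arrivals from one full warpgroup (128 threads) or the
            # whole block (256). Waiters pass a phase parity (k % 2 inside the loop), so
            # a single barrier is reused across iterations in ping-pong fashion.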

            tx = T.get_thread_binding()

            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.barrier_arrive(q_shared_ready_barrier)
            T.barrier_wait(q_shared_ready_barrier, 0)

            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(seqlen_kv, (block_N * 2))
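            # Each iteration of the k-loop consumes two KV tiles of block_N rows:
            # tile 2k is scored by warpgroup 0, tile 2k+1 by warpgroup 1.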

            if tx < 128:
                T.copy(Q_pe_shared, Q_pe_local_0)
                T.fill(acc_o_l, 0)
                T.fill(logsum_0, 0)

                T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
                T.barrier_arrive(kv_shared_1_l_is_ready)

                T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
                T.barrier_arrive(kv_shared_1_r_is_ready)

                T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], K_pe_shared_1)
                T.barrier_arrive(kv_shared_1_pe_is_ready)

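                # Warpgroup 0: the copies above prefetch tile 1 for warpgroup 1. Each
                # iteration scores tile 2k, refreshes the shared max, accumulates the
                # left output half for both tiles, and prefetches the left halves of
                # tiles 2k+2 and 2k+3 plus K_pe of tile 2k+3.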
                for k in T.serial(loop_range):

                    T.barrier_wait(kv_shared_0_l_is_ready, k % 2)
                    T.gemm(
                        Q_shared_l,
                        KV_shared_0_l,
                        acc_s_0,
                        transpose_B=True,
                        clear_accum=True,
                        wg_wait=-1)
                    T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
                    T.gemm(
                        Q_shared_r,
                        KV_shared_0_r,
                        acc_s_0,
                        transpose_B=True,
                        wg_wait=-1)

                    T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
                    T.gemm(
                        Q_pe_local_0,
                        K_pe_shared_0,
                        acc_s_0,
                        transpose_B=True,
                        wg_wait=-1)

                    T.wait_wgmma(0)

                    # Step 3. Save the previous shared max, then overwrite scores_max
                    # with the row max of tile 2k's scores.
                    T.copy(scores_max, scores_max_0)
                    T.copy(scores_max_0, scores_max_prev_0)
                    T.fill(scores_max_0, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s_0, scores_max_0, dim=1, clear=False)
                    T.copy(scores_max_0, scores_max)

                    # Step 4. exp2-based softmax weights for tile 2k (log2(e) is folded
                    # into `scale`) and the rescale factor relative to the previous max.
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale)
                    for i in T.Parallel(block_H):
                        scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale -
                                                   scores_max[i] * scale)

                    T.reduce_sum(acc_s_0, scores_sum_0, dim=1)

                    # Step 5. Cast the probabilities to fp16; rescale the left-half
                    # accumulator and running row sum.
                    T.copy(acc_s_0, acc_s_0_cast)

                    for i, j in T.Parallel(block_H, h_dim):
                        acc_o_l[i, j] *= scores_scale_0[i]

                    for i in T.Parallel(block_H):
                        logsum_0[i] = logsum_0[i] * scores_scale_0[i] + scores_sum_0[i]

                    # Step 6. acc_o_l += P(2k) @ V_left(2k).
                    T.gemm(acc_s_0_cast, KV_shared_0_l, acc_o_l)
                    T.barrier_arrive(score_max_0_ready_barrier)

                    T.barrier_wait(scale_1_ready_barrier, k % 2)

                    if k < loop_range - 1:
                        T.copy(
                            KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N,
                               cur_kv_head, :h_dim], KV_shared_0_l)
                        T.barrier_arrive(kv_shared_0_l_is_ready)

                    # Step 11. Publish tile-2k probabilities, rescaled by warpgroup 1's
                    # factor, for its V_right GEMM in Step 12.
                    for i, j in T.Parallel(block_H, block_N):
                        SP0_shared[i, j] = acc_s_0[i, j] * scores_scale_1[i]

                    T.barrier_arrive(p0_1_1_ready_barrier)

                    # Step 13. Apply warpgroup 1's rescale factor to the left-half
                    # accumulator and running row sum.
                    for i, j in T.Parallel(block_H, h_dim):
                        acc_o_l[i, j] *= scores_scale_1[i]
                    for i in T.Parallel(block_H):
                        logsum_0[i] = logsum_0[i] * scores_scale_1[i]
                    T.barrier_wait(s_shared_ready_barrier, k % 2)

                    # Step 14. acc_o_l += P(2k+1) @ V_left(2k+1), using the probabilities
                    # published by warpgroup 1.
                    T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)

                    if k < loop_range - 1:

                        T.copy(
                            KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N,
                               cur_kv_head, :h_dim], KV_shared_1_l)
                        T.barrier_arrive(kv_shared_1_l_is_ready)

                        T.copy(
                            K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :],
                            K_pe_shared_1)
                        T.barrier_arrive(kv_shared_1_pe_is_ready)

                T.copy(logsum_0, logsum)
                T.barrier_arrive(lse_0_ready_barrier)
                T.barrier_wait(lse_1_ready_barrier, 0)
                for i, j in T.Parallel(block_H, h_dim):
                    acc_o_l[i, j] /= logsum[i]
                T.copy(acc_o_l, O_shared_l)
                T.copy(O_shared_l, Output[bid,
                                          hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim])

            else:
                T.copy(Q_pe_shared, Q_pe_local_1)
                T.fill(acc_o_r, 0)
                T.fill(logsum_1, 0)

                T.copy(KV[bid, :block_N, cur_kv_head, :h_dim], KV_shared_0_l)
                T.barrier_arrive(kv_shared_0_l_is_ready)
                T.copy(KV[bid, :block_N, cur_kv_head, h_dim:], KV_shared_0_r)
                T.barrier_arrive(kv_shared_0_r_is_ready)
                T.copy(K_pe[bid, :block_N, cur_kv_head, :], K_pe_shared_0)
                T.barrier_arrive(kv_shared_0_pe_is_ready)

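                # Warpgroup 1: the copies above prefetch tile 0 for warpgroup 0. Each
                # iteration scores tile 2k+1, derives the second rescale factor,
                # accumulates the right output half for both tiles, and prefetches the
                # right halves of tiles 2k+2 and 2k+3 plus K_pe of tile 2k+2.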
                for k in T.serial(loop_range):

                    # Step 2. Score GEMMs for tile 2k+1: Q @ K^T over both head-dim
                    # halves plus the positional-encoding part.
                    T.barrier_wait(kv_shared_1_l_is_ready, k % 2)
                    T.gemm(
                        Q_shared_l,
                        KV_shared_1_l,
                        acc_s_1,
                        transpose_B=True,
                        clear_accum=True,
                        wg_wait=-1)

                    T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
                    T.gemm(
                        Q_shared_r,
                        KV_shared_1_r,
                        acc_s_1,
                        transpose_B=True,
                        wg_wait=-1)

                    T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
                    T.gemm(
                        Q_pe_local_1,
                        K_pe_shared_1,
                        acc_s_1,
                        transpose_B=True,
                        wg_wait=-1)

                    T.wait_wgmma(0)

                    # Step 7. Once warpgroup 0 has published its max, save it, overwrite
                    # scores_max with tile 2k+1's row max, and derive this warpgroup's
                    # rescale factor.
                    T.barrier_wait(score_max_0_ready_barrier, k % 2)

                    T.copy(scores_max, scores_max_prev_1)
                    T.fill(scores_max_1, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s_1, scores_max_1, dim=1, clear=False)
                    T.copy(scores_max_1, scores_max)

                    for i in T.Parallel(block_H):
                        scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale -
                                                   scores_max[i] * scale)

                    # Step 8. exp2-based softmax weights for tile 2k+1.
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s_1[i, j] = T.exp2(acc_s_1[i, j] * scale - scores_max[i] * scale)

                    # Step 9. Row sums for tile 2k+1; rescale the right-half accumulator
                    # and running row sum by both warpgroups' factors.
                    T.reduce_sum(acc_s_1, scores_sum_1, dim=1)

                    for i, j in T.Parallel(block_H, h_dim):
                        acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i])

                    for i in T.Parallel(block_H):
                        logsum_1[i] = (logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] +
                                       scores_sum_1[i])

                    T.barrier_arrive(scale_1_ready_barrier)

                    # Step 10. acc_o_r += P(2k+1) @ V_right(2k+1) using KV_shared_1_r;
                    # then publish the fp16 probabilities (SP1_shared) for warpgroup 0.
                    T.copy(acc_s_1, acc_s_1_cast)
                    T.gemm(
                        acc_s_1_cast,
                        KV_shared_1_r,
                        acc_o_r,
                        wg_wait=-1)
                    T.copy(acc_s_1_cast, SP1_shared)
                    T.barrier_arrive(s_shared_ready_barrier)

                    if k < loop_range - 1:
                        T.copy(
                            KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head,
                               h_dim:], KV_shared_1_r)
                        T.barrier_arrive(kv_shared_1_r_is_ready)

                    T.barrier_wait(p0_1_1_ready_barrier, k % 2)
                    # Step 12. acc_o_r += P(2k) @ V_right(2k), using the probabilities
                    # published by warpgroup 0.
                    T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)

                    if k < loop_range - 1:

                        T.copy(
                            KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head,
                               h_dim:], KV_shared_0_r)
                        T.barrier_arrive(kv_shared_0_r_is_ready)

                        T.copy(
                            K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :],
                            K_pe_shared_0)
                        T.barrier_arrive(kv_shared_0_pe_is_ready)

                T.barrier_wait(lse_0_ready_barrier, 0)
                for i in T.Parallel(block_H):
                    logsum[i] += logsum_1[i]
                T.barrier_arrive(lse_1_ready_barrier)
                for i, j in T.Parallel(block_H, h_dim):
                    acc_o_r[i, j] /= logsum[i]
                T.copy(acc_o_r, O_shared_r)
                T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
                                          h_dim:])

    @T.prim_func
    def main_no_split(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
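        # glse and Output_partial belong to the split-KV interface; the no-split
        # path ignores them.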
        flash_attn(Q, Q_pe, KV, K_pe, Output)

    return main_no_split


def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    """
    Inputs:
    - q (Tensor): [batch, heads, dim]
    - q_pe (Tensor): [batch, heads, pe_dim]
    - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    - glse (Tensor): [batch, heads, num_split]
    - Output_partial (Tensor): [batch, heads, num_split, dim]
    Outputs:
    - output (Tensor): [batch, heads, dim]
    """
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    num_head_groups = q.shape[1] // kv.shape[2]
    scale = (dim + pe_dim)**0.5
    q = rearrange(
        q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]

    q_pe = rearrange(
        q_pe, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]

    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch, kv_head_num, seqlen_kv, dim]

    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch, kv_head_num, seqlen_kv, pe_dim]

    query = torch.concat([q, q_pe], dim=-1)
    key = torch.concat([kv, k_pe], dim=-1)

    scores = einsum(
        query, key,
        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]

    attention = F.softmax(
        scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]

    out = einsum(attention, kv,
                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
    return out


def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
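    # FLOP count: one GEMM Q @ K^T over (dim + pe_dim) channels and one GEMM P @ V
    # over dim channels, at 2 FLOPs per multiply-accumulate.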
    qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
    pv_flops = 2 * batch * heads * kv_ctx * dim
    total_flops = qk_flops + pv_flops
    BLOCK_N = 64
    BLOCK_H = 64
    num_split = 1

    kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
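    # Sketch of a direct launch (with out_idx=[6] the runtime allocates and returns
    # `Output`): out = kernel(q, q_pe, kv, k_pe, glse, output_partial)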
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    latency = profiler.do_bench(warmup=500)
    print(f"Latency: {latency} ms")
    print(f"TFlops: {total_flops / latency * 1e-9} TFlops")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=128, help='q heads number')
    parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
    parser.add_argument('--dim', type=int, default=512, help='head dim')
    parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
    args = parser.parse_args()
    batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
    main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)