# example_warp_specialize_flashmla.py
1
2
3
import torch
import torch.nn.functional as F
import tilelang
4
from tilelang.autotuner import *
5
6
import tilelang.language as T
from einops import rearrange, einsum
7
8
import argparse

9

10
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
    """Build a warp-specialized FlashMLA-style decode attention kernel.

    Each CTA (256 threads) handles ``VALID_BLOCK_H`` query heads of one batch
    entry and walks the KV sequence in pairs of ``block_N``-long tiles. The
    threads form two 128-thread groups:

    * ``tx < 128`` computes scores for the even tiles ``2k`` and accumulates the
      left half of the output (``[:dim // 2]``);
    * ``tx >= 128`` computes scores for the odd tiles ``2k + 1`` and accumulates
      the right half of the output (``[dim // 2:]``).

    The groups prefetch tiles for each other and exchange the running row max,
    rescale factors and probability tiles through shared memory, synchronized
    with ``T.alloc_barrier`` barriers. ``KV`` serves both as the non-PE part of
    the keys and as the values.

    Args:
        batch: batch size.
        heads: number of query heads.
        kv_head_num: number of KV heads; only 1 is supported.
        seqlen_kv: KV sequence length; must be a positive multiple of
            ``2 * block_N`` (tiles are processed in pairs, and the tail is not
            masked).
        dim: last dimension of ``Q``/``KV``/``Output``; must be even because it
            is split in half between the two thread groups.
        pe_dim: last dimension of ``Q_pe``/``K_pe``; must equal ``block_N``
            because the ``Q_pe`` shared buffer is reused for a
            ``[block_H, block_N]`` probability tile.
        block_N: KV tile length.
        block_H: query-head tile size.
        num_split: only sizes the unused ``glse``/``Output_partial`` parameters.

    Returns:
        ``main_no_split`` as processed by ``tilelang.jit``. ``out_idx=[6]``
        presumably marks ``Output`` (parameter index 6) as the tensor tilelang
        allocates and returns — confirm against the tilelang docs.
    """
    assert kv_head_num == 1, "kv_head_num must be 1"
    # 1/sqrt(dim + pe_dim) with log2(e) folded in, so T.exp2 can stand in for exp.
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    # Real query heads per CTA; smaller than block_H when heads < block_H.
    VALID_BLOCK_H = min(block_H, kv_group_num)
    # Width of each half of dim owned by one thread group.
    h_dim = dim // 2

    # Shape constraints the kernel body relies on but cannot check itself.
    assert kv_group_num % VALID_BLOCK_H == 0, "heads must be a multiple of min(block_H, heads)"
    assert dim % 2 == 0, "dim must be even"
    assert pe_dim == block_N, "pe_dim must equal block_N (Q_pe_shared is reused as the P1 tile)"
    assert seqlen_kv > 0 and seqlen_kv % (2 * block_N) == 0, \
        "seqlen_kv must be a positive multiple of 2 * block_N (the tail is not masked)"

    @T.macro
    def flash_attn(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        # Grid: (query-head tiles, batch); 256 threads = two 128-thread groups.
        with T.Kernel(heads // VALID_BLOCK_H, batch, threads=256) as (hid, bid):
            # smem_sQ
            Q_shared_l = T.alloc_shared([block_H, h_dim], dtype)
            Q_shared_r = T.alloc_shared([block_H, h_dim], dtype)
            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
            # Per-group register copies of Q_pe, which frees Q_pe_shared for reuse below.
            Q_pe_local_0 = T.alloc_fragment([block_H, pe_dim], dtype)
            Q_pe_local_1 = T.alloc_fragment([block_H, pe_dim], dtype)

            # smem_sK0: even tiles (2k), split into left/right halves of dim plus the PE part.
            KV_shared_0_l = T.alloc_shared([block_N, h_dim], dtype)
            KV_shared_0_r = T.alloc_shared([block_N, h_dim], dtype)
            K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)

            # smem_sK1: odd tiles (2k + 1).
            KV_shared_1_l = T.alloc_shared([block_N, h_dim], dtype)
            KV_shared_1_r = T.alloc_shared([block_N, h_dim], dtype)
            K_pe_shared_1 = T.alloc_shared([block_N, pe_dim], dtype)

            # smem_sP0: P of tile 2k, rescaled, published by group 0 for group 1.
            SP0_shared = T.alloc_shared([block_H, block_N], dtype)

            # smem_sP1 reuses Q_pe_shared (requires pe_dim == block_N). Both groups copy
            # Q_pe_shared into registers before group 1 first writes P1 here.
            SP1_shared = Q_pe_shared

            # smem_sM: running row max shared by both groups.
            scores_max = T.alloc_shared([block_H], accum_dtype)

            # smem_sScale0: exp2(m_prev - m) after tile 2k (written by group 0).
            scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
            # smem_sScale1: exp2(m_prev - m) after tile 2k + 1 (written by group 1).
            scores_scale_1 = T.alloc_shared([block_H], accum_dtype)

            # Final softmax denominator, combined from logsum_0 and logsum_1.
            logsum = T.alloc_shared([block_H], accum_dtype)

            # Q is not read after the main loop, so its buffers are reused for the output.
            O_shared_l = Q_shared_l
            O_shared_r = Q_shared_r

            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_0_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_s_1 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_1_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o_l = T.alloc_fragment([block_H, h_dim], accum_dtype)
            acc_o_r = T.alloc_fragment([block_H, h_dim], accum_dtype)
            scores_max_0 = T.alloc_fragment([block_H], accum_dtype)
            scores_max_1 = T.alloc_fragment([block_H], accum_dtype)

            scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype)

            scores_sum_0 = T.alloc_fragment([block_H], accum_dtype)
            scores_sum_1 = T.alloc_fragment([block_H], accum_dtype)
            logsum_0 = T.alloc_fragment([block_H], accum_dtype)
            logsum_1 = T.alloc_fragment([block_H], accum_dtype)

            # KV head read by this CTA's query heads. Dividing by VALID_BLOCK_H (not block_H)
            # avoids a division by zero when heads < block_H.
            cur_kv_head = hid // (kv_group_num // VALID_BLOCK_H)

            T.annotate_layout({
                O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
                O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
            })

            # barriers_Q: all 256 threads load Q, then all wait.
            q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)

            # barriers_K0 / barriers_K1: each is arrived at by the one 128-thread group that
            # loaded the buffer; waits use phase parity k % 2 (one completion per iteration).
            kv_shared_0_l_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_0_r_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_0_pe_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_1_l_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_1_r_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_1_pe_is_ready = T.alloc_barrier(arrive_count=128)

            # Cross-group hand-off barriers for the shared softmax state and P tiles.
            score_max_0_ready_barrier = T.alloc_barrier(arrive_count=128)
            scale_1_ready_barrier = T.alloc_barrier(arrive_count=128)
            p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128)
            lse_0_ready_barrier = T.alloc_barrier(arrive_count=128)
            lse_1_ready_barrier = T.alloc_barrier(arrive_count=128)
            s_shared_ready_barrier = T.alloc_barrier(arrive_count=128)

            tx = T.get_thread_binding()

            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.barrier_arrive(q_shared_ready_barrier)
            T.barrier_wait(q_shared_ready_barrier, 0)

            T.fill(scores_max, -T.infinity(accum_dtype))

            # Number of (even, odd) tile pairs.
            loop_range = T.ceildiv(seqlen_kv, (block_N * 2))

            if tx < 128:
                # Group 0: scores of even tiles, left half of the output.
                T.copy(Q_pe_shared, Q_pe_local_0)
                T.fill(acc_o_l, 0)
                T.fill(logsum_0, 0)

                # Prefetch tile 1 for group 1.
                T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
                T.barrier_arrive(kv_shared_1_l_is_ready)

                T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
                T.barrier_arrive(kv_shared_1_r_is_ready)

                T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], K_pe_shared_1)
                T.barrier_arrive(kv_shared_1_pe_is_ready)

                for k in T.serial(loop_range):
                    # S0 = [Q | Q_pe] @ [KV | K_pe]_{2k}^T, issued asynchronously as each
                    # part of tile 2k becomes ready.
                    T.barrier_wait(kv_shared_0_l_is_ready, k % 2)
                    T.gemm(
                        Q_shared_l,
                        KV_shared_0_l,
                        acc_s_0,
                        transpose_B=True,
                        clear_accum=True,
                        wg_wait=-1)
                    T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
                    T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)

                    T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
                    T.gemm(Q_pe_local_0, K_pe_shared_0, acc_s_0, transpose_B=True, wg_wait=-1)

                    T.wait_wgmma(0)

                    # Step 3. Update the running row max with tile 2k.
                    T.copy(scores_max, scores_max_0)
                    T.copy(scores_max_0, scores_max_prev_0)
                    T.fill(scores_max_0, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s_0, scores_max_0, dim=1, clear=False)
                    # Combine with the previous max so it never decreases. Otherwise the
                    # rescale factors exp2(prev - new) can exceed 1 and the rescaled P0 tile
                    # stored in fp16 SP0_shared can overflow.
                    for i in T.Parallel(block_H):
                        scores_max_0[i] = T.max(scores_max_0[i], scores_max_prev_0[i])
                    T.copy(scores_max_0, scores_max)

                    # Step 4. P0 = exp2((S0 - m) * scale); scale_0 = exp2((m_prev - m) * scale).
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale)
                    for i in T.Parallel(block_H):
                        scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale -
                                                   scores_max[i] * scale)

                    T.reduce_sum(acc_s_0, scores_sum_0, dim=1)

                    # Step 5. Cast P0 for the MMA; rescale O_l and logsum_0 to the new max.
                    T.copy(acc_s_0, acc_s_0_cast)

                    for i, j in T.Parallel(block_H, h_dim):
                        acc_o_l[i, j] *= scores_scale_0[i]

                    for i in T.Parallel(block_H):
                        logsum_0[i] = logsum_0[i] * scores_scale_0[i] + scores_sum_0[i]

                    # Step 6. O_l += P0 @ V_{2k}[:, :h_dim].
                    T.gemm(acc_s_0_cast, KV_shared_0_l, acc_o_l)

                    # Hand the max/scale_0 to group 1, then wait for its scale_1.
                    T.barrier_arrive(score_max_0_ready_barrier)

                    T.barrier_wait(scale_1_ready_barrier, k % 2)

                    if k < loop_range - 1:
                        T.copy(
                            KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N,
                               cur_kv_head, :h_dim], KV_shared_0_l)
                        T.barrier_arrive(kv_shared_0_l_is_ready)

                    # Step 11. Publish P0 re-expressed relative to the max after tile 2k + 1.
                    for i, j in T.Parallel(block_H, block_N):
                        SP0_shared[i, j] = acc_s_0[i, j] * scores_scale_1[i]

                    T.barrier_arrive(p0_1_1_ready_barrier)

                    # Step 13. Rescale O_l and logsum_0 by scale_1.
                    for i, j in T.Parallel(block_H, h_dim):
                        acc_o_l[i, j] *= scores_scale_1[i]
                    for i in T.Parallel(block_H):
                        logsum_0[i] = logsum_0[i] * scores_scale_1[i]
                    T.barrier_wait(s_shared_ready_barrier, k % 2)

                    # Step 14. O_l += P1 @ V_{2k+1}[:, :h_dim].
                    T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)

                    if k < loop_range - 1:

                        T.copy(
                            KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N,
                               cur_kv_head, :h_dim], KV_shared_1_l)
                        T.barrier_arrive(kv_shared_1_l_is_ready)

                        T.copy(
                            K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :],
                            K_pe_shared_1)
                        T.barrier_arrive(kv_shared_1_pe_is_ready)

                # Publish logsum_0; group 1 adds logsum_1 into `logsum`.
                T.copy(logsum_0, logsum)
                T.barrier_arrive(lse_0_ready_barrier)
                T.barrier_wait(lse_1_ready_barrier, 0)
                for i, j in T.Parallel(block_H, h_dim):
                    acc_o_l[i, j] /= logsum[i]
                T.copy(acc_o_l, O_shared_l)
                T.copy(O_shared_l, Output[bid,
                                          hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim])

            else:
                # Group 1: scores of odd tiles, right half of the output.
                T.copy(Q_pe_shared, Q_pe_local_1)
                T.fill(acc_o_r, 0)
                T.fill(logsum_1, 0)

                # Prefetch tile 0 for group 0.
                T.copy(KV[bid, :block_N, cur_kv_head, :h_dim], KV_shared_0_l)
                T.barrier_arrive(kv_shared_0_l_is_ready)
                T.copy(KV[bid, :block_N, cur_kv_head, h_dim:], KV_shared_0_r)
                T.barrier_arrive(kv_shared_0_r_is_ready)
                T.copy(K_pe[bid, :block_N, cur_kv_head, :], K_pe_shared_0)
                T.barrier_arrive(kv_shared_0_pe_is_ready)

                for k in T.serial(loop_range):

                    # Step 2. S1 = [Q | Q_pe] @ [KV | K_pe]_{2k+1}^T.
                    T.barrier_wait(kv_shared_1_l_is_ready, k % 2)
                    T.gemm(
                        Q_shared_l,
                        KV_shared_1_l,
                        acc_s_1,
                        transpose_B=True,
                        clear_accum=True,
                        wg_wait=-1)

                    T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
                    T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)

                    T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
                    T.gemm(Q_pe_local_1, K_pe_shared_1, acc_s_1, transpose_B=True, wg_wait=-1)

                    T.wait_wgmma(0)

                    # Step 7. Wait for group 0's max (after tile 2k), fold in tile 2k + 1.
                    T.barrier_wait(score_max_0_ready_barrier, k % 2)

                    T.copy(scores_max, scores_max_prev_1)
                    T.fill(scores_max_1, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s_1, scores_max_1, dim=1, clear=False)
                    # Keep the max running across tiles (see Step 3).
                    for i in T.Parallel(block_H):
                        scores_max_1[i] = T.max(scores_max_1[i], scores_max_prev_1[i])
                    T.copy(scores_max_1, scores_max)

                    for i in T.Parallel(block_H):
                        scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale -
                                                   scores_max[i] * scale)

                    # Step 8. P1 = exp2((S1 - m) * scale).
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s_1[i, j] = T.exp2(acc_s_1[i, j] * scale - scores_max[i] * scale)

                    # Step 9. Row sums of P1; rescale O_r and logsum_1 by scale_0 * scale_1.
                    T.reduce_sum(acc_s_1, scores_sum_1, dim=1)

                    for i, j in T.Parallel(block_H, h_dim):
                        acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i])

                    for i in T.Parallel(block_H):
                        logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[
                            i] + scores_sum_1[i]

                    T.barrier_arrive(scale_1_ready_barrier)

                    # Step 10. O_r += P1 @ V_{2k+1}[:, h_dim:] (async); publish P1 to SP1_shared.
                    T.copy(acc_s_1, acc_s_1_cast)
                    T.gemm(acc_s_1_cast, KV_shared_1_r, acc_o_r, wg_wait=-1)
                    T.copy(acc_s_1_cast, SP1_shared)
                    T.barrier_arrive(s_shared_ready_barrier)

                    # The async GEMM above still reads KV_shared_1_r; let it finish before
                    # the prefetch below overwrites that buffer.
                    T.wait_wgmma(0)

                    if k < loop_range - 1:
                        T.copy(
                            KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head,
                               h_dim:], KV_shared_1_r)
                        T.barrier_arrive(kv_shared_1_r_is_ready)

                    T.barrier_wait(p0_1_1_ready_barrier, k % 2)
                    # Step 12. O_r += SP0 @ V_{2k}[:, h_dim:].
                    T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)

                    if k < loop_range - 1:

                        T.copy(
                            KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head,
                               h_dim:], KV_shared_0_r)
                        T.barrier_arrive(kv_shared_0_r_is_ready)

                        T.copy(
                            K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :],
                            K_pe_shared_0)
                        T.barrier_arrive(kv_shared_0_pe_is_ready)

                # Combine the two partial denominators, then normalize the right half.
                T.barrier_wait(lse_0_ready_barrier, 0)
                for i in T.Parallel(block_H):
                    logsum[i] += logsum_1[i]
                T.barrier_arrive(lse_1_ready_barrier)
                for i, j in T.Parallel(block_H, h_dim):
                    acc_o_r[i, j] /= logsum[i]
                T.copy(acc_o_r, O_shared_r)
                T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
                                          h_dim:])

    @T.prim_func
    def main_no_split(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        # glse / Output_partial are unused (there is no split-KV path); they keep the
        # parameter list aligned with out_idx=[6] and with ref_program's signature.
        flash_attn(Q, Q_pe, KV, K_pe, Output)

    return main_no_split


def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    """Reference MLA decode attention in plain PyTorch.

    Inputs:
    - q (Tensor): [batch, heads, dim]
    - q_pe (Tensor): [batch, heads, pe_dim]
    - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    - glse (Tensor): [batch, heads, num_split]; unused.
    - Output_partial (Tensor): [batch, heads, num_split, dim]; unused.
      (Both are accepted so the signature mirrors the kernel's input parameters.)
    Outputs:
    - output (Tensor): [batch, heads, dim], same dtype as ``q``.

    Keys are ``concat(kv, k_pe)`` and values are ``kv``. Consecutive query heads
    share a KV head: query head ``i`` attends to KV head
    ``i // (heads // kv_head_num)``. Scores are divided by
    ``sqrt(dim + pe_dim)`` before the softmax. The math runs in float32, and the
    result is cast back to ``q.dtype`` so that half-precision rounding does not
    limit the reference.
    """
    batch, heads, dim = q.shape
    pe_dim = q_pe.shape[-1]
    kv_head_num = kv.shape[2]
    num_head_groups = heads // kv_head_num
    scale = (dim + pe_dim)**0.5
    compute_dtype = torch.float32

    # [batch, kv_head_num, num_head_groups, dim + pe_dim]: query heads grouped by KV head.
    query = torch.cat([q, q_pe], dim=-1).to(compute_dtype).view(
        batch, kv_head_num, num_head_groups, dim + pe_dim)
    # [batch, kv_head_num, seqlen_kv, dim + pe_dim]
    key = torch.cat([kv, k_pe], dim=-1).to(compute_dtype).transpose(1, 2)
    # [batch, kv_head_num, seqlen_kv, dim]
    value = kv.to(compute_dtype).transpose(1, 2)

    # [batch, kv_head_num, num_head_groups, seqlen_kv]
    scores = torch.einsum('bhgd,bhsd->bhgs', query, key)
    attention = F.softmax(scores / scale, dim=-1)

    # [batch, kv_head_num, num_head_groups, dim] -> [batch, heads, dim]
    out = torch.einsum('bhgs,bhsd->bhgd', attention, value)
    return out.reshape(batch, heads, dim).to(q.dtype)


def main(batch=1,
         heads=64,
         kv_heads=1,
         kv_ctx=1024,
         dim=512,
         pe_dim=64,
         block_N=64,
         block_H=64):
    """Compile, check and benchmark the warp-specialized FlashMLA decode kernel.

    The function:
    - prints the generated kernel source;
    - asserts the kernel output matches ``ref_program`` (rtol=atol=0.01) on
      random-normal inputs;
    - prints the measured latency and throughput.

    Args:
        batch, heads, kv_heads, kv_ctx, dim, pe_dim: problem sizes forwarded to
            ``flashattn`` (``kv_ctx`` is the KV sequence length).
        block_N: KV tile length, default 64. ``flashattn`` reuses its ``Q_pe``
            shared buffer as a ``[block_H, block_N]`` tile, so this should equal
            ``pe_dim``.
        block_H: query-head tile size, default 64.
    """
    qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
    pv_flops = 2 * batch * heads * kv_ctx * dim
    total_flops = qk_flops + pv_flops
    # The kernel has no split-KV path; num_split only sizes its unused buffers.
    num_split = 1

    kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, block_N, block_H, num_split)
    print(kernel.get_kernel_source())

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    latency = profiler.do_bench(warmup=500)
    print(f"Latency: {latency} ms")
    # latency is in ms, so flops / ms * 1e-9 == TFLOP/s.
    print(f"TFlops: {total_flops / latency * 1e-9} TFlops")


if __name__ == "__main__":
    # Command-line entry point: every option is an int problem size forwarded to main().
    cli = argparse.ArgumentParser()
    for flag, default_value, help_text in (
        ('--batch', 1, 'batch size'),
        ('--heads', 128, 'q heads number'),
        ('--kv_heads', 1, 'kv heads number'),
        ('--kv_ctx', 8192, 'kv context length'),
        ('--dim', 512, 'head dim'),
        ('--pe_dim', 64, 'pe head dim'),
    ):
        cli.add_argument(flag, type=int, default=default_value, help=help_text)
    opts = cli.parse_args()
    main(opts.batch, opts.heads, opts.kv_heads, opts.kv_ctx, opts.dim, opts.pe_dim)