import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse


@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
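    # softmax below uses exp2, so the 1/sqrt(dim + pe_dim) scale is pre-multiplied by log2(e)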
    scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504  # log2(e)
    dtype = T.float16
    accum_dtype = T.float32
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"
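    # the head dim is split into left/right halves (h_dim each); each warp group accumulates one half of the output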
    h_dim = dim // 2

    @T.macro
    def flash_attn(
        Q: T.Tensor([batch, heads, dim], dtype),
        Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
        KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
        K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
        Output: T.Tensor([batch, heads, dim], dtype),
    ):
        with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
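            # 256 threads form two warp groups of 128. The `if tx < 128` / `else`
            # branches below specialize them: each group handles one of the two KV
            # tiles per iteration and accumulates one half of the output head dim,
            # exchanging softmax statistics through shared memory and barriers.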
            # smem_sQ
            Q_shared_l = T.alloc_shared([block_H, h_dim], dtype)
            Q_shared_r = T.alloc_shared([block_H, h_dim], dtype)
            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
            Q_pe_local_0 = T.alloc_fragment([block_H, pe_dim], dtype)
            Q_pe_local_1 = T.alloc_fragment([block_H, pe_dim], dtype)

            # smem_sK0
            KV_shared_0_l = T.alloc_shared([block_N, h_dim], dtype)
            KV_shared_0_r = T.alloc_shared([block_N, h_dim], dtype)
            K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)

            # smem_sK1
            KV_shared_1_l = T.alloc_shared([block_N, h_dim], dtype)
            KV_shared_1_r = T.alloc_shared([block_N, h_dim], dtype)
            K_pe_shared_1 = T.alloc_shared([block_N, pe_dim], dtype)

            # smem_sP0
            SP0_shared = T.alloc_shared([block_H, block_N], dtype)

            # smem_sP1 reuse Q_pe_shared
            SP1_shared = Q_pe_shared

            # smem_sM
            scores_max = T.alloc_shared([block_H], accum_dtype)

            # smem_sScale0
            scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
            # smem_sScale1
            scores_scale_1 = T.alloc_shared([block_H], accum_dtype)

            logsum = T.alloc_shared([block_H], accum_dtype)

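            # smem_sO: the output tiles reuse the Q tiles' shared memory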
            O_shared_l = Q_shared_l
            O_shared_r = Q_shared_r

            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_0_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_s_1 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_1_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o_l = T.alloc_fragment([block_H, h_dim], accum_dtype)
            acc_o_r = T.alloc_fragment([block_H, h_dim], accum_dtype)
            scores_max_0 = T.alloc_fragment([block_H], accum_dtype)
            scores_max_1 = T.alloc_fragment([block_H], accum_dtype)

            scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype)

            scores_sum_0 = T.alloc_fragment([block_H], accum_dtype)
            scores_sum_1 = T.alloc_fragment([block_H], accum_dtype)
            logsum_0 = T.alloc_fragment([block_H], accum_dtype)
            logsum_1 = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = hid // (kv_group_num // block_H)

            T.annotate_layout(
                {
                    O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
                    O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
                }
            )

            # barriers_Q
            q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)

            # barriers_K0
            kv_shared_0_l_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_0_r_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_0_pe_is_ready = T.alloc_barrier(arrive_count=128)
            # barriers_K1
            kv_shared_1_l_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_1_r_is_ready = T.alloc_barrier(arrive_count=128)
            kv_shared_1_pe_is_ready = T.alloc_barrier(arrive_count=128)

            # redundant barriers
            score_max_0_ready_barrier = T.alloc_barrier(arrive_count=128)
            scale_1_ready_barrier = T.alloc_barrier(arrive_count=128)
            p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128)
            lse_0_ready_barrier = T.alloc_barrier(arrive_count=128)
            lse_1_ready_barrier = T.alloc_barrier(arrive_count=128)
            s_shared_ready_barrier = T.alloc_barrier(arrive_count=128)

            tx = T.get_thread_binding()
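            # tx selects the warp group: threads 0..127 take the `if` branch below,
            # threads 128..255 take the `else` branch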

            T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
            T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
            T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
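            # all 256 threads arrive here so both warp groups see the Q tiles in
            # shared memory before their first GEMMs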
            T.barrier_arrive(q_shared_ready_barrier)
            T.barrier_wait(q_shared_ready_barrier, 0)

            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(seqlen_kv, (block_N * 2))
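            # each iteration covers two KV tiles of block_N rows: tile 2k for warp
            # group 0 and tile 2k+1 for warp group 1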

            if tx < 128:
                T.copy(Q_pe_shared, Q_pe_local_0)
                T.fill(acc_o_l, 0)
                T.fill(logsum_0, 0)

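                # prologue: stage KV tile 1 (rows block_N..2*block_N) for warp
                # group 1 to consume in its first iteration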
                T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
                T.barrier_arrive(kv_shared_1_l_is_ready)

                T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
                T.barrier_arrive(kv_shared_1_r_is_ready)

                T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1)
                T.barrier_arrive(kv_shared_1_pe_is_ready)

                for k in T.serial(loop_range):
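                    # Step 1: QK^T for KV tile 0 (left half, right half, then the
                    # positional part), issued back-to-back with wg_wait=-1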
                    T.barrier_wait(kv_shared_0_l_is_ready, k % 2)
                    T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1)
                    T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
                    T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)

                    T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
                    T.gemm(Q_pe_local_0, K_pe_shared_0, acc_s_0, transpose_B=True, wg_wait=-1)

                    T.wait_wgmma(0)

                    # Step 3: save the previous row max and publish this tile's row max to the shared scores_max
                    T.copy(scores_max, scores_max_0)
                    T.copy(scores_max_0, scores_max_prev_0)
                    T.fill(scores_max_0, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s_0, scores_max_0, dim=1, clear=False)
                    T.copy(scores_max_0, scores_max)

                    # Step 4: exponentiate the scores and compute group 0's rescale factor from the max update
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale)
                    for i in T.Parallel(block_H):
                        scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale)

                    T.reduce_sum(acc_s_0, scores_sum_0, dim=1)

                    # Step 5: cast the probabilities to fp16 and rescale the left-half accumulator and logsum
                    T.copy(acc_s_0, acc_s_0_cast)

                    for i, j in T.Parallel(block_H, h_dim):
                        acc_o_l[i, j] *= scores_scale_0[i]

                    for i in T.Parallel(block_H):
                        logsum_0[i] = logsum_0[i] * scores_scale_0[i] + scores_sum_0[i]

                    # Step 6: accumulate P0 @ V0 (left half) into acc_o_l
                    T.gemm(acc_s_0_cast, KV_shared_0_l, acc_o_l)
                    T.barrier_arrive(score_max_0_ready_barrier)

                    T.barrier_wait(scale_1_ready_barrier, k % 2)

                    if k < loop_range - 1:
                        T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l)
                        T.barrier_arrive(kv_shared_0_l_is_ready)

                    # Step 11: rescale P0 by group 1's scale factor and publish it for group 1
                    for i, j in T.Parallel(block_H, block_N):
                        SP0_shared[i, j] = acc_s_0[i, j] * scores_scale_1[i]

                    T.barrier_arrive(p0_1_1_ready_barrier)

                    # Step 13: apply group 1's rescale factor to the left-half accumulator and logsum
                    for i, j in T.Parallel(block_H, h_dim):
                        acc_o_l[i, j] *= scores_scale_1[i]
                    for i in T.Parallel(block_H):
                        logsum_0[i] = logsum_0[i] * scores_scale_1[i]
                    T.barrier_wait(s_shared_ready_barrier, k % 2)

                    # Step 14: accumulate P1 @ V1 (left half) using the probabilities published by group 1
                    T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)

                    if k < loop_range - 1:
                        T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
                        T.barrier_arrive(kv_shared_1_l_is_ready)

                        T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1)
                        T.barrier_arrive(kv_shared_1_pe_is_ready)

                # epilogue: publish group 0's logsum, wait for the combined total,
                # then normalize and write the left half of the output
                T.copy(logsum_0, logsum)
                T.barrier_arrive(lse_0_ready_barrier)
                T.barrier_wait(lse_1_ready_barrier, 0)
                for i, j in T.Parallel(block_H, h_dim):
                    acc_o_l[i, j] /= logsum[i]
                T.copy(acc_o_l, O_shared_l)
                T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim])

            else:
                T.copy(Q_pe_shared, Q_pe_local_1)
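                # warp group 1: accumulates the right half of the output and stages
                # KV tile 0 for the first iteration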
                T.fill(acc_o_r, 0)
                T.fill(logsum_1, 0)

                T.copy(KV[bid, :block_N, cur_kv_head, :h_dim], KV_shared_0_l)
                T.barrier_arrive(kv_shared_0_l_is_ready)
                T.copy(KV[bid, :block_N, cur_kv_head, h_dim:], KV_shared_0_r)
                T.barrier_arrive(kv_shared_0_r_is_ready)
                T.copy(K_pe[bid, :block_N, cur_kv_head, :], K_pe_shared_0)
                T.barrier_arrive(kv_shared_0_pe_is_ready)

                for k in T.serial(loop_range):
                    # Step 2: QK^T for KV tile 1 (left half, right half, then the positional part)
                    T.barrier_wait(kv_shared_1_l_is_ready, k % 2)
                    T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1)

                    T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
                    T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)

                    T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
                    T.gemm(Q_pe_local_1, K_pe_shared_1, acc_s_1, transpose_B=True, wg_wait=-1)

                    T.wait_wgmma(0)

                    # Step 7: wait for group 0's max, then publish this tile's row max to the shared scores_max
                    T.barrier_wait(score_max_0_ready_barrier, k % 2)

                    T.copy(scores_max, scores_max_prev_1)
                    T.fill(scores_max_1, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s_1, scores_max_1, dim=1, clear=False)
                    T.copy(scores_max_1, scores_max)

                    for i in T.Parallel(block_H):
                        scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale)

                    # Step 8: exponentiate tile 1's scores against the shared row max
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s_1[i, j] = T.exp2(acc_s_1[i, j] * scale - scores_max[i] * scale)

                    # Step 9: row sums, then rescale the right-half accumulator and logsum by both scale factors
                    T.reduce_sum(acc_s_1, scores_sum_1, dim=1)

                    for i, j in T.Parallel(block_H, h_dim):
                        acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i])

                    for i in T.Parallel(block_H):
                        logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i]

                    T.barrier_arrive(scale_1_ready_barrier)

                    # Step 10: accumulate P1 @ V1 (right half) with KV_shared_1_r and publish P1 for group 0
                    T.copy(acc_s_1, acc_s_1_cast)
                    T.gemm(acc_s_1_cast, KV_shared_1_r, acc_o_r, wg_wait=-1)
                    T.copy(acc_s_1_cast, SP1_shared)
                    T.barrier_arrive(s_shared_ready_barrier)

                    if k < loop_range - 1:
                        T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
                        T.barrier_arrive(kv_shared_1_r_is_ready)

                    T.barrier_wait(p0_1_1_ready_barrier, k % 2)
                    # Step 12: accumulate P0 @ V0 (right half) using the probabilities published by group 0
                    T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)

                    if k < loop_range - 1:
                        T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r)
                        T.barrier_arrive(kv_shared_0_r_is_ready)

                        T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0)
                        T.barrier_arrive(kv_shared_0_pe_is_ready)

                # epilogue: fold this group's logsum into the shared total, then
                # normalize and write the right half of the output
                T.barrier_wait(lse_0_ready_barrier, 0)
                for i in T.Parallel(block_H):
                    logsum[i] += logsum_1[i]
                T.barrier_arrive(lse_1_ready_barrier)
                for i, j in T.Parallel(block_H, h_dim):
                    acc_o_r[i, j] /= logsum[i]
                T.copy(acc_o_r, O_shared_r)
                T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:])

    @T.prim_func
    def main_no_split(
        Q: T.Tensor([batch, heads, dim], dtype),
        Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
        KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
        K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
        glse: T.Tensor([batch, heads, num_split], dtype),
        Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
        Output: T.Tensor([batch, heads, dim], dtype),
    ):
        flash_attn(Q, Q_pe, KV, K_pe, Output)
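        # glse and Output_partial are not used on this no-split path; presumably they
        # keep the signature compatible with a split-KV (num_split > 1) variant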

    return main_no_split


def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    """
    Inputs:
    - q (Tensor): [batch, heads, dim]
    - q_pe (Tensor): [batch, heads, pe_dim]
    - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    - glse (Tensor): [batch, heads, num_split]
    - Output_partial (Tensor): [batch, heads, num_split, dim]
    Outputs:
    - output (Tensor): [batch, heads, dim]
    """
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    num_head_groups = q.shape[1] // kv.shape[2]
    scale = (dim + pe_dim) ** 0.5
    q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, kv_head_num, dim]

    q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, kv_head_num, pe_dim]

    kv = rearrange(kv, "b n h d -> b h n d")  # [batch_size, kv_head_num, seqlen_kv, dim]

    k_pe = rearrange(k_pe, "b n h d -> b h n d")  # [batch_size, kv_head_num, seqlen_kv, pe_dim]

    query = torch.concat([q, q_pe], dim=-1)
    key = torch.concat([kv, k_pe], dim=-1)

    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, kv_head_num, seqlen_kv]

    attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, kv_head_num, seqlen_kv]

    out = einsum(attention, kv, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, kv_head_num, dim]
    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
    qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
    pv_flops = 2 * batch * heads * kv_ctx * dim
    total_flops = qk_flops + pv_flops
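    # 2 FLOPs per multiply-accumulate: QK^T over (dim + pe_dim) plus P @ V over dim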
    BLOCK_N = 64
    BLOCK_H = 64
    num_split = 1

    kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    latency = profiler.do_bench(warmup=500)
    print(f"Latency: {latency} ms")
    print(f"TFlops: {total_flops / latency * 1e-9} TFlops")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=1, help="batch size")
    parser.add_argument("--heads", type=int, default=128, help="q heads number")
    parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
    parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
    parser.add_argument("--dim", type=int, default=512, help="head dim")
    parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
    args = parser.parse_args()
    batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
    main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)