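"""example_mla_decode.py: TileLang example of an MLA (multi-head latent attention)
decode kernel.

A FlashAttention-style decode pass over a compressed KV cache plus a separate
positional-encoding component (Q_pe / K_pe). When num_split > 1 the KV cache is
processed in independent splits whose partial outputs and log-sum-exp values are
merged by a combine kernel. A PyTorch/einops reference implementation verifies
correctness before benchmarking.
"""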
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse


def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
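    # The softmax is evaluated with exp2, so log2(e) is folded into the
    # 1/sqrt(dim + pe_dim) attention scale below. Each thread block handles
    # VALID_BLOCK_H query heads, all of which attend to the single compressed
    # KV head (kv_head_num is asserted to be 1).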
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"

    @T.macro
    def flash_attn(
            Q: T.Buffer([batch, heads, dim], dtype),
            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            Output: T.Buffer([batch, heads, dim], dtype),
    ):
        with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
            KV_shared = T.alloc_shared([block_N, dim], dtype)
            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
            O_shared = T.alloc_shared([block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            bid = bx
            hid = by
            cur_kv_head = hid // (kv_group_num // block_H)

            T.use_swizzle(10)

            T.annotate_layout({
                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                S_shared: tilelang.layout.make_swizzled_layout(S_shared),
            })

            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(seqlen_kv, block_N)
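            # Stream the KV cache in block_N tiles (2-stage software pipeline),
            # keeping an online softmax: scores_max/logsum are rescaled every
            # iteration and acc_o accumulates the running output.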
            for k in T.Pipelined(loop_range, num_stages=2):
                kv_start = k * block_N
                kv_end = (k + 1) * block_N

                T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
                T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)

                T.clear(acc_s_0)
                T.gemm(
                    Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(
                    Q_pe_shared,
                    K_pe_shared,
                    acc_s_0,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol)
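                # acc_s_0 now holds the logits for this tile:
                # Q @ KV^T + Q_pe @ K_pe^T (both GEMMs accumulate into it).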
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.copy(acc_s_0, S_shared)
                T.copy(S_shared, acc_s)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]

            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :])

    @T.macro
    def flash_attn_split(
            Q: T.Buffer([batch, heads, dim], dtype),
            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
    ):
        with T.Kernel(
                batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
            KV_shared = T.alloc_shared([block_N, dim], dtype)
            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
            O_shared = T.alloc_shared([block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            bid = bx
            hid = by
            sid = bz
            cur_kv_head = hid // (kv_group_num // block_H)

            T.use_swizzle(10)

            T.annotate_layout({
                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                S_shared: tilelang.layout.make_swizzled_layout(S_shared),
            })

            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
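            # Split-KV: each grid z index (sid) handles its own contiguous
            # seqlen_kv // num_split slice of the cache (this assumes seqlen_kv
            # divides evenly by num_split); the loop body is the same online
            # softmax as in flash_attn above.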
            for k in T.Pipelined(loop_range, num_stages=2):
                kv_start = (seqlen_kv // num_split) * sid + k * block_N
                kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N

                T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
                T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)

                T.clear(acc_s_0)
                T.gemm(
                    Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(
                    Q_pe_shared,
                    K_pe_shared,
                    acc_s_0,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.copy(acc_s_0, S_shared)
                T.copy(S_shared, acc_s)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
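            # Turn the per-split running sum into a base-2 log-sum-exp
            # (scores_max is still scaled by log2(e)); the combine kernel uses
            # these values to merge the partial outputs across splits.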
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale

            T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
                                            sid, :])

    @T.macro
    def combine(
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
            Output: T.Buffer([batch, heads, dim], dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim], dtype)
            o_accum_local = T.alloc_fragment([dim], accum_dtype)
            lse_local_split = T.alloc_local([1], accum_dtype)
            lse_logsum_local = T.alloc_local([1], accum_dtype)
            lse_max_local = T.alloc_local([1], accum_dtype)
            scale_local = T.alloc_local([1], accum_dtype)

            T.annotate_layout({
                lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
            })

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
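            # Numerically stable merge of the per-split results: find the max
            # LSE over splits, sum exp2(lse_k - max) to get the global LSE,
            # then rescale each partial output by exp2(lse_k - global_lse).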
            lse_max_local[0] = -T.infinity(accum_dtype)  # must start at -inf, not uninitialized
            for k in T.serial(num_split):
                lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
            for k in T.Pipelined(num_split, num_stages=1):
                lse_local_split[0] = glse[bz, by, k]
                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
            lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
            for k in T.serial(num_split):
                for i in T.Parallel(dim):
                    po_local[i] = Output_partial[bz, by, k, i]
                lse_local_split[0] = glse[bz, by, k]
                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                for i in T.Parallel(dim):
                    o_accum_local[i] += po_local[i] * scale_local[0]
            for i in T.Parallel(dim):
                Output[bz, by, i] = o_accum_local[i]

    @T.prim_func
    def main_split(
            Q: T.Buffer([batch, heads, dim], dtype),
            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
            Output: T.Buffer([batch, heads, dim], dtype),
    ):
        flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
        combine(glse, Output_partial, Output)

    @T.prim_func
    def main_no_split(
            Q: T.Buffer([batch, heads, dim], dtype),
            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
            Output: T.Buffer([batch, heads, dim], dtype),
    ):
        flash_attn(Q, Q_pe, KV, K_pe, Output)

    if num_split > 1:
        return main_split
    else:
        return main_no_split


def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    """
    Reference attention implementation used to verify the TileLang kernels.

    Inputs:
    - q (Tensor): [batch, heads, dim]
    - q_pe (Tensor): [batch, heads, pe_dim]
    - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    - glse (Tensor): [batch, heads, num_split]; unused placeholder so the
      signature matches the kernel's buffer list
    - Output_partial (Tensor): [batch, heads, num_split, dim]; unused placeholder

    Outputs:
    - output (Tensor): [batch, heads, dim]
    """
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    num_head_groups = q.shape[1] // kv.shape[2]
    scale = (dim + pe_dim)**0.5
    q = rearrange(
        q, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch_size, num_head_groups, kv_head_num, dim]

    q_pe = rearrange(
        q_pe, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch_size, num_head_groups, kv_head_num, pe_dim]

    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, dim]

    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, pe_dim]

    query = torch.concat([q, q_pe], dim=-1)
    key = torch.concat([kv, k_pe], dim=-1)

    scores = einsum(
        query, key,
        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, kv_head_num, seqlen_kv]

    attention = F.softmax(
        scores / scale, dim=-1)  # [batch_size, num_head_groups, kv_head_num, seqlen_kv]

    out = einsum(attention, kv,
                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, kv_head_num, dim]
    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=128, help='batch size')
    parser.add_argument('--heads', type=int, default=128, help='number of query heads')
    parser.add_argument('--kv_heads', type=int, default=1, help='number of KV heads')
    parser.add_argument('--kv_ctx', type=int, default=8192, help='KV context length')
    parser.add_argument('--dim', type=int, default=512, help='head dim')
    parser.add_argument('--pe_dim', type=int, default=64, help='positional-encoding head dim')
    args = parser.parse_args()
    batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
    qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
    pv_flops = 2 * batch * heads * kv_ctx * dim
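    # qk_flops counts the Q·K^T logits GEMM over dim + pe_dim channels,
    # pv_flops the attention-weighted value GEMM over dim channels
    # (2 FLOPs per multiply-accumulate).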
    total_flops = qk_flops + pv_flops
    BLOCK_N = 64
    BLOCK_H = 64
    num_split = 1
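    # With num_split == 1, flashattn returns the single-pass kernel
    # (main_no_split); num_split > 1 selects the split-KV + combine path.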

    program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
    mod, params = tilelang.lower(program)
    mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Randn)
    mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("All close")
    latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))