import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from functools import partial

num_split = 4
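# The KV sequence is processed in num_split chunks (split-KV / "flash decoding" style):
# flash_attn_split writes one partial output and one log-sum-exp per chunk, and a second
# kernel (combine) merges the partial results into the final attention output.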


@tilelang.jit(out_idx=[5])
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e); folding log2(e) in lets the softmax use exp2 below
    shape_q = [batch, seqlen_q, heads, dim]
    shape_kv = [batch, seqlen_kv, heads, dim]
    part_shape = [batch, seqlen_q, heads, num_split, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.macro
    def MMA0(
        K: T.Tensor(shape_kv, dtype),
        Q_shared: T.SharedBuffer([block_M, dim], dtype),
        K_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        k: T.int32,
        mid: T.int32,
        hid: T.int32,
        bid: T.int32,
        sid: T.int32,
    ):
        T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared)
        # TODO: Handle causal split case
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def MMA1(
        V: T.Tensor(shape_kv, dtype),
        V_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        k: T.int32,
        hid: T.int32,
        bid: T.int32,
        sid: T.int32,
    ):
        T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        scores_max: T.FragmentBuffer([block_M], accum_dtype),
        scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),
        scores_sum: T.FragmentBuffer([block_M], accum_dtype),
        logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        for i in T.Parallel(block_M):
            scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
        # For causal softmax, scores_max must be set to 0 when it is -inf.
        # This step is called Check_inf in the FlashAttention3 code, and it only needs to be
        # done in the first ceil_div(kBlockM, kBlockN) steps.
        # for i in T.Parallel(block_M):
        #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
        for i in T.Parallel(block_M):
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)). This allows the compiler to use the ffma
            # instruction instead of separate fadd and fmul instructions.
            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
        T.reduce_sum(acc_s, scores_sum, dim=1)
        for i in T.Parallel(block_M):
            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        T.copy(acc_s, acc_s_cast)

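    # Online-softmax recurrence implemented by Softmax (above) and Rescale/MMA1 (below),
    # in base 2 with scale = 1/sqrt(dim) * log2(e) folded in:
    #   m_new = max(m_prev, rowmax(S))
    #   p     = exp2(S * scale - m_new * scale)
    #   l_new = l_prev * exp2(m_prev * scale - m_new * scale) + rowsum(p)
    #   O_new = O_prev * exp2(m_prev * scale - m_new * scale) + p @ V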
    @T.macro
    def Rescale(
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.macro
    def flash_attn_split(
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_kv, dtype),
        V: T.Tensor(shape_kv, dtype),
        glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
        Output_partial: T.Tensor(part_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

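            # Grid mapping: bx -> query block (mid), by -> fused batch/head index (bid, hid),
            # bz -> KV split (sid); each split covers seqlen_kv // num_split keys and values.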
            mid = bx
            hid = by % heads
            bid = by // heads
            sid = bz

            # NOTE(wt): the TMA barrier currently has problems with padded dimensions (seqlen_q here),
            # so disable the relevant TMA copies and fall back to SIMT for now.
            T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # TODO: Handle causal split case
            loop_range = (
                T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N))
                if is_causal
                else T.ceildiv((seqlen_kv // num_split), block_N)
            )

            for k in T.Pipelined(loop_range, num_stages=2):
                MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True)

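    # Combine: merge the num_split partial outputs with their per-split log-sum-exp
    # values stored in glse (all logarithms are base 2 to match the exp2 math above):
    #   lse_total = log2(sum_s exp2(lse_s - lse_max)) + lse_max
    #   Output    = sum_s exp2(lse_s - lse_total) * Output_partial_s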
    @T.macro
    def combine(
        glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
        Output_partial: T.Tensor(part_shape, dtype),
        Output: T.Tensor(shape_q, dtype),
    ):
        with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
            po_local = T.alloc_fragment([block_M, dim], dtype)
            po_shared = T.alloc_shared([block_M, dim], dtype)
            o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype)
            o_shared = T.alloc_shared([block_M, dim], dtype)
            lse_local = T.alloc_fragment([num_split, block_M], dtype)
            lse_local_split = T.alloc_fragment([block_M], accum_dtype)
            lse_logsum_local = T.alloc_fragment([block_M], accum_dtype)
            lse_max_local = T.alloc_fragment([block_M], accum_dtype)
            scale_local = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout(
                {
                    o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
                    o_shared: tilelang.layout.make_swizzled_layout(o_shared),
                    po_shared: tilelang.layout.make_swizzled_layout(po_shared),
                }
            )

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            T.copy(
                glse[
                    bz,
                    by,
                    :,
                    bx * block_M : (bx + 1) * block_M,
                ],
                lse_local,
            )
            T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
            for k in T.Pipelined(num_split):
                T.copy(lse_local[k, :], lse_local_split)
                for i in T.Parallel(block_M):
                    lse_logsum_local[i] += T.exp2(lse_local_split[i] - lse_max_local[i])
            for i in T.Parallel(block_M):
                lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
            for k in T.Pipelined(num_split, num_stages=2):
                T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True)
                T.copy(po_shared, po_local)
                for i in T.Parallel(block_M):
                    lse_local_split[i] = lse_local[k, i]
                for i in T.Parallel(block_M):
                    scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i])
                for i, j in T.Parallel(block_M, dim):
                    o_accum_local[i, j] += po_local[i, j] * scale_local[i]
            T.copy(o_accum_local, o_shared)
            T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True)

    @T.prim_func
    def flashattn_mha_inference(
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_kv, dtype),
        V: T.Tensor(shape_kv, dtype),
        glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
        Output_partial: T.Tensor(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
        Output: T.Tensor(shape_q, dtype),
    ):
        flash_attn_split(Q, K, V, glse, Output_partial)
        combine(glse, Output_partial, Output)

    return flashattn_mha_inference


def ref_program(Q, K, V, glse, Output_partial, causal):
    assert causal is False
    dim = Q.size(-1)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output


def reduce_ref(Q, K, V, glse, Output_partial, causal):
    o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0)
    lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0)  # [batch, heads, seqlen_q]
    lse_max = glse.max(dim=2, keepdim=False).values
    for ks in range(num_split):
        lse = glse[:, :, ks, :]
        lse_logsum += torch.exp2(lse - lse_max)
    lse_logsum = torch.log2(lse_logsum) + lse_max
    for ks in range(num_split):
        lse = glse[:, :, ks, :]
        scale = torch.exp2(lse - lse_logsum)  # [batch, heads, seqlen_q]
        o += Output_partial[:, :, :, ks, :] * scale[:, :, :, None].transpose(1, 2)
    return o.to(torch.float16)


def flash_split_ref(Q, K, V, causal):
    # [batch, seqlen_q, heads, dim]
    batch = Q.size(0)
    block_M = Q.size(1)  # the whole query length is treated as a single block
    nheads = Q.size(2)
    dim = Q.size(3)
    block_N = 128
    seqlen_kv = K.size(1)

    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
    acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
    acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
    scores_max = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    scores_max_prev = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    scores_scale = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    scores_sum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    logsum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    gacc_o = torch.empty((num_split, batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
    glogsum = torch.empty((num_split, batch, nheads, block_M), device="cuda", dtype=torch.float)

    Q_ = Q * scale

    for ks in range(num_split):
        acc_o.fill_(0)
        logsum.fill_(0)
        scores_max.fill_(float("-inf"))
        scores_max_prev.fill_(float("-inf"))
        for i in range(int((seqlen_kv // num_split) / block_N)):
            acc_s.fill_(0)
            acc_s = torch.einsum(
                "bqhd,bkhd->bhqk",
                Q_,
                K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
            )  # [batch, nheads, block_M, block_N]
            scores_max_prev = scores_max
            scores_max = acc_s.max(dim=-1, keepdim=False).values  # [batch, nheads, block_M]
            scores_scale = torch.exp2(scores_max_prev - scores_max)
            acc_o *= scores_scale[:, :, :, None].transpose(1, 2)
            acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
            acc_s_cast = acc_s.to(torch.float16)
            acc_o += torch.einsum(
                "bhqk,bkhd->bqhd",
                acc_s_cast,
                V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
            )
            scores_sum = acc_s.sum(dim=-1, keepdim=False)
            logsum = logsum * scores_scale + scores_sum
        acc_o /= logsum[:, :, :, None].transpose(1, 2)
        logsum = torch.log2(logsum) + scores_max
        gacc_o[ks, :, :, :, :] = acc_o
        glogsum[ks, :, :, :] = logsum

    return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)


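# Optional sanity check (not part of the original example): a minimal sketch that
# verifies the split reference path (flash_split_ref followed by reduce_ref) against
# the plain attention reference. The helper name, shapes, and tolerances are
# illustrative assumptions; it needs a CUDA device and a KV length divisible by
# num_split * 128 (the block_N used in flash_split_ref).
def check_split_reference(BATCH=1, H=8, Q_CTX=128, KV_CTX=1024, D_HEAD=64):
    Q = torch.randn(BATCH, Q_CTX, H, D_HEAD, device="cuda", dtype=torch.float16)
    K = torch.randn(BATCH, KV_CTX, H, D_HEAD, device="cuda", dtype=torch.float16)
    V = torch.randn(BATCH, KV_CTX, H, D_HEAD, device="cuda", dtype=torch.float16)
    glse, output_partial = flash_split_ref(Q, K, V, causal=False)
    out_split = reduce_ref(Q, K, V, glse, output_partial, causal=False)
    out_full = ref_program(Q, K, V, glse, output_partial, causal=False)
    torch.testing.assert_close(out_split, out_full, rtol=1e-2, atol=1e-2)
    print("Split reference matches the full attention reference.")

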
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
    flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    BLOCK_M = 128
    BLOCK_N = 64  # if D_HEAD <= 128 else 32
    kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
    ref_fn = partial(ref_program, causal=causal)
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    profiler.assert_allclose(ref_fn, rtol=0.01, atol=0.01)
    print("All checks passed!")

    # Benchmark the PyTorch reference implementation.
    latency = profiler.do_bench(ref_fn, warmup=500)
    print("{:.2f} ms".format(latency))
    print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
    # Benchmark the TileLang split-KV kernel.
    latency = profiler.do_bench(n_warmup=10, n_repeat=10)
    print("{:.4f} ms".format(latency))
    print("{:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    main()