# example_mha_fwd_bhsd.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
    """Return the autotuning search space as a list of keyword dictionaries.

    Each dictionary holds one combination of ``block_M``, ``block_N``,
    ``num_stages`` and ``threads``; the list is the Cartesian product of the
    candidate values declared below, in ``itertools.product`` order.
    """
    # Candidate values per tunable parameter. Insertion order fixes both the
    # key order of every emitted dict and the order of the product.
    search_space = {
        'block_M': [128],
        'block_N': [128],
        'num_stages': [2],
        'threads': [256],
    }
    param_names = list(search_space)
    return [
        dict(zip(param_names, combo))
        for combo in itertools.product(*search_space.values())
    ]


30
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
    """Build a tile-lang attention forward program for [batch, heads, seq, dim] tensors.

    The generated program takes ``Q`` and ``Output`` of shape
    ``[batch, heads, seq_q, dim]`` and ``K``/``V`` of shape
    ``[batch, heads, seq_kv, dim]``, all ``float16``, and accumulates in
    ``float``. Keys/values are consumed in tiles of ``block_N`` rows while a
    running row maximum and row sum are maintained, so the softmax is computed
    without materialising the full score matrix.

    Args:
        batch: Batch size.
        heads: Number of attention heads.
        seq_q: Query sequence length.
        seq_kv: Key/value sequence length.
        dim: Head dimension.
        is_causal: If true, query row ``q`` only sees keys with index
            ``<= q + (seq_kv - seq_q)``.
            NOTE(review): when ``seq_kv < seq_q`` some query rows have no
            visible key at all; that case is not handled specially here.
        tune: If true, run the autotuner over ``get_configs()``; otherwise
            return a builder the caller invokes with explicit tile sizes.

    Returns:
        With ``tune=False``: a function
        ``kernel(block_M, block_N, num_stages, threads)`` returning the
        ``T.prim_func``. With ``tune=True``: whatever the
        ``autotune``/``jit``-decorated kernel returns when called (the script
        at the bottom of this file unpacks it as
        ``best_latency, best_config, _``).
    """
    # Softmax scale 1/sqrt(dim), pre-multiplied by log2(e) so exp2 can stand in for exp.
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, heads, seq_kv, dim]
    dtype = "float16"
    accum_dtype = "float"
    # Offset added to query indices by the causal mask. It is shared by the
    # mask (MMA0) and by the K/V loop bound (main) so the two always agree.
    past_len = seq_kv - seq_q

    def kernel_func(block_M, block_N, num_stages, threads):
        """Create the prim_func for one (block_M, block_N, num_stages, threads) choice."""

        @T.macro
        def MMA0(
            K: T.Buffer(kv_shape, dtype),
            Q_shared: T.Buffer([block_M, dim], dtype),
            K_shared: T.Buffer([block_N, dim], dtype),
            acc_s: T.Buffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
            # Load the k-th K tile, seed acc_s with the mask (0 or -inf) or
            # zeros, then accumulate Q_shared @ K_shared^T on top of it.
            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    q_idx = bx * block_M + i + past_len
                    k_idx = k * block_N + j
                    acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
                V: T.Buffer(kv_shape, dtype),
                # The V tile has block_N rows: it matches the slice copied
                # below and the [block_N, dim] allocation in main.
                V_shared: T.Buffer([block_N, dim], dtype),
                acc_s_cast: T.Buffer([block_M, block_N], dtype),
                acc_o: T.Buffer([block_M, dim], accum_dtype),
                k: T.int32,
                by: T.int32,
                bz: T.int32,
        ):
            # Load the k-th V tile and accumulate acc_s_cast @ V_shared into acc_o.
            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
                acc_s: T.Buffer([block_M, block_N], accum_dtype),
                acc_s_cast: T.Buffer([block_M, block_N], dtype),
                scores_max: T.Buffer([block_M], accum_dtype),
                scores_max_prev: T.Buffer([block_M], accum_dtype),
                scores_scale: T.Buffer([block_M], accum_dtype),
                scores_sum: T.Buffer([block_M], accum_dtype),
                logsum: T.Buffer([block_M], accum_dtype),
        ):
            # Remember the previous row maximum, then recompute it for this tile.
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            # To do causal softmax, we need to set the scores_max to 0 if it is -inf
            # This process is called Check_inf in FlashAttention3 code, and it only need to be done
            # in the first ceil_div(kBlockM, kBlockN) steps.
            # for i in T.Parallel(block_M):
            #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            # Factor used to rescale logsum (here) and acc_o (in Rescale).
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, block_N):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)) This allows the compiler to use the ffma
                # instruction instead of fadd and fmul separately.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            # Cast the probabilities to the input dtype for the second GEMM.
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
                acc_o: T.Buffer([block_M, dim], accum_dtype),
                scores_scale: T.Buffer([block_M], accum_dtype),
        ):
            # Scale every output row by its scores_scale factor.
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def main(
                Q: T.Buffer(q_shape, dtype),
                K: T.Buffer(kv_shape, dtype),
                V: T.Buffer(kv_shape, dtype),
                Output: T.Buffer(q_shape, dtype),
        ):
            # Launch grid: (query tiles, heads, batch) -> (bx, by, bz).
            with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                # Number of K/V tiles to visit. In the causal case the last
                # query row of this tile has index (bx + 1) * block_M - 1 and
                # may see keys up to that index plus past_len, so the bound
                # must include past_len to stay consistent with the mask built
                # in MMA0; it is clamped to the number of existing K/V tiles.
                loop_range = (
                    T.min(
                        T.ceildiv(seq_kv, block_N),
                        T.ceildiv((bx + 1) * block_M + past_len, block_N))
                    if is_causal else T.ceildiv(seq_kv, block_N))

                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                            scores_sum, logsum)
                    Rescale(acc_o, scores_scale)
                    MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                # Final normalisation by the accumulated softmax denominator.
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

        return main

    if tune:

        @autotune(
            configs=get_configs(),
            keys=["block_M", "block_N", "num_stages", "threads"],
            warmup=10,
            rep=10)
        @jit(
            out_idx=[3],
            supply_type=tilelang.TensorSupplyType.Integer,
            ref_prog=None,
            profiler="auto")
        def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel()
    else:

        def kernel(block_M, block_N, num_stages, threads):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel


180
def ref_program(Q, K, V, is_causal):
    """Compute reference scaled dot-product attention with PyTorch.

    Args:
        Q: Query tensor indexed as ``[batch, heads, seq_q, dim]``.
        K: Key tensor indexed as ``[batch, heads, seq_kv, dim]``.
        V: Value tensor indexed as ``[batch, heads, seq_kv, dim]``.
        is_causal: If true, query row ``q`` only attends to keys with index
            ``<= q + (seq_kv - seq_q)``. This is the same rule the tile-lang
            kernel in this file applies (``q_idx + past_len >= k_idx``), so
            the two agree when ``seq_q != seq_kv`` as well.

    Returns:
        Tensor indexed as ``[batch, heads, seq_q, dim]``.
    """
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_q = Q.size(2)
        seq_kv = K.size(2)
        # Shift the diagonal by seq_kv - seq_q so the last query row sees every
        # key; for seq_q == seq_kv this is the ordinary lower-triangular mask.
        mask = torch.tril(
            torch.ones(seq_q, seq_kv, device=scores.device), diagonal=seq_kv - seq_q)
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    return output


if __name__ == "__main__":
    # Command-line driver: either validate and benchmark one fixed
    # configuration against the PyTorch reference, or run the autotuner.
    parser = argparse.ArgumentParser()
    int_options = (
        ('--batch', 1, 'batch size'),
        ('--heads', 1, 'heads'),
        ('--seq_q', 256, 'query sequence length'),
        ('--seq_kv', 256, 'key/value sequence length'),
        ('--dim', 64, 'dim'),
    )
    for flag, default_value, help_text in int_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()

    batch = args.batch
    heads = args.heads
    seq_q = args.seq_q
    seq_kv = args.seq_kv
    dim = args.dim
    is_causal = args.is_causal

    # Two matmuls (Q @ K^T and P @ V), each 2 * batch * heads * seq_q * seq_kv * dim flops.
    flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        # The script counts a causal run as half the work.
        total_flops *= 0.5

    if args.tune:
        best_latency, best_config, _ = flashattn(
            batch, heads, seq_q, seq_kv, dim, is_causal, tune=True)
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
    else:
        build = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False)
        program = build(block_M=64, block_N=64, num_stages=1, threads=128)
        reference = partial(ref_program, is_causal=is_causal)
        kernel = tilelang.compile(program, out_idx=[3])

        profiler = kernel.get_profiler()
        profiler.assert_allclose(reference, rtol=0.01, atol=0.01)
        print("All checks pass.")

        ref_latency = profiler.do_bench(reference, warmup=500)
        print("Ref: {:.2f} ms".format(ref_latency))
        print("Ref: {:.2f} TFlops".format(total_flops / ref_latency * 1e-9))

        tl_latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(tl_latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / tl_latency * 1e-9))