# example_mha_fwd_bhsd.py
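"""FlashAttention-style forward pass for tensors in BHSD layout, written in TileLang.

Queries are processed in tiles of block_M rows while K/V are streamed in tiles of
block_N rows; a running row-max and row-sum keep the softmax numerically stable
without materializing the full attention matrix (online softmax). Causal masking
and distinct query / key-value lengths are supported, tile parameters can be
autotuned with --tune, and the result is checked against a plain PyTorch reference.
"""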
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
    block_M = [128]
    block_N = [128]
    num_stages = [2]
    threads = [256]
    _configs = list(itertools.product(block_M, block_N, num_stages, threads))

    configs = [{
        'block_M': c[0],
        'block_N': c[1],
        'num_stages': c[2],
        'threads': c[3]
    } for c in _configs]
    return configs
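
# get_configs() enumerates the Cartesian product of the candidate tile parameters
# above; the current lists contain a single combination. Widening the lists
# (illustrative values, not part of the original sweep) gives the autotuner more
# candidates, e.g. block_M = [64, 128], num_stages = [1, 2, 3], threads = [128, 256].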


def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)

    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, heads, seq_kv, dim]

    dtype = "float16"
    accum_dtype = "float"

    def kernel_func(block_M, block_N, num_stages, threads):

        @T.macro
        def MMA0(
            K: T.Buffer(kv_shape, dtype),
            Q_shared: T.Buffer([block_M, dim], dtype),
            K_shared: T.Buffer([block_N, dim], dtype),
            acc_s: T.Buffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
            # past_len shifts the causal diagonal so that the last query attends to the
            # last key (bottom-right aligned mask when seq_kv > seq_q).
            past_len = seq_kv - seq_q
            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    q_idx = bx * block_M + i + past_len
                    k_idx = k * block_N + j
                    acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
                V: T.Buffer(kv_shape, dtype),
                V_shared: T.Buffer([block_N, dim], dtype),
                acc_s_cast: T.Buffer([block_M, block_N], dtype),
                acc_o: T.Buffer([block_M, dim], accum_dtype),
                k: T.int32,
                by: T.int32,
                bz: T.int32,
        ):
            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

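        # Online-softmax bookkeeping used by the Softmax/Rescale macros below, with
        # P = Q @ K^T / sqrt(dim), running row-max m, running row-sum l, and
        # unnormalized output accumulator O:
        #   m_new = max(m_old, rowmax(P_tile))
        #   r     = exp(m_old - m_new)                       # `scores_scale`
        #   l_new = l_old * r + rowsum(exp(P_tile - m_new))  # `logsum`
        #   O_new = O_old * r + exp(P_tile - m_new) @ V_tile
        # The code folds 1/sqrt(dim) and log2(e) into `scale` and evaluates the
        # exponentials with exp2; `main` divides O by l after the last K/V tile.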
        @T.macro
        def Softmax(
                acc_s: T.Buffer([block_M, block_N], accum_dtype),
                acc_s_cast: T.Buffer([block_M, block_N], dtype),
                scores_max: T.Buffer([block_M], accum_dtype),
                scores_max_prev: T.Buffer([block_M], accum_dtype),
                scores_scale: T.Buffer([block_M], accum_dtype),
                scores_sum: T.Buffer([block_M], accum_dtype),
                logsum: T.Buffer([block_M], accum_dtype),
        ):
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            # To do causal softmax, we need to set scores_max to 0 if it is -inf.
            # This step is called Check_inf in the FlashAttention3 code, and it only needs
            # to be done in the first ceil_div(kBlockM, kBlockN) steps.
            # for i in T.Parallel(block_M):
            #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, block_N):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)). This allows the compiler to use the ffma
                # instruction instead of separate fadd and fmul instructions.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
                acc_o: T.Buffer([block_M, dim], accum_dtype),
                scores_scale: T.Buffer([block_M], accum_dtype),
        ):
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

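        # Grid mapping: one thread block per (query tile, head, batch) triple.
        # bx indexes tiles of block_M query rows, by selects the head, bz the batch.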
        @T.prim_func
        def main(
                Q: T.Buffer(q_shape, dtype),
                K: T.Buffer(kv_shape, dtype),
                V: T.Buffer(kv_shape, dtype),
                Output: T.Buffer(q_shape, dtype),
        ):
            with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                # In the causal case, only loop over K/V tiles that can be attended to;
                # the (seq_kv - seq_q) offset matches the past_len shift used in MMA0.
                loop_range = (
                    T.min(
                        T.ceildiv(seq_kv, block_N),
                        T.ceildiv((bx + 1) * block_M + (seq_kv - seq_q), block_N))
                    if is_causal else T.ceildiv(seq_kv, block_N))

                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                            scores_sum, logsum)
                    Rescale(acc_o, scores_scale)
                    MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

        return main

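    # tune=True wraps the builder in @autotune/@jit and runs the search immediately,
    # returning (best latency, best config, ...); otherwise a plain closure is returned
    # and the caller supplies block_M/block_N/num_stages/threads explicitly.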
    if tune:

        @autotune(
            configs=get_configs(),
            keys=["block_M", "block_N", "num_stages", "threads"],
            warmup=10,
            rep=10)
        @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
        def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel()
    else:

        def kernel(block_M, block_N, num_stages, threads):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_q = Q.size(2)
        seq_kv = K.size(2)
        # Bottom-right aligned causal mask, matching the past_len offset in the kernel:
        # query i may attend to keys j <= i + (seq_kv - seq_q).
        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), diagonal=seq_kv - seq_q)
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    return output
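
# Note: for seq_q == seq_kv the same reference can also be computed with PyTorch's
# built-in F.scaled_dot_product_attention(Q, K, V, is_causal=is_causal)
# (an alternative check; not used by this script).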


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=1, help='heads')
    parser.add_argument('--seq_q', type=int, default=256, help='query sequence length')
    parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length')
    parser.add_argument('--dim', type=int, default=64, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    batch, heads, seq_q, seq_kv, dim, is_causal = args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal
    flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        # Causal masking skips roughly half of the score matrix
        # (exactly half only when seq_q == seq_kv).
        total_flops *= 0.5

    if (not args.tune):
        program = flashattn(
            batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)(
                block_M=64, block_N=64, num_stages=1, threads=128)
        ref_program = partial(ref_program, is_causal=is_causal)
        kernel = tilelang.compile(program, out_idx=[3])

        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = profiler.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        best_latency, best_config, _ = flashattn(
            batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
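
# Example invocations (assuming TileLang is installed and a CUDA GPU is available):
#   python example_mha_fwd_bhsd.py --batch 1 --heads 8 --seq_q 1024 --seq_kv 1024 --dim 64 --is_causal
#   python example_mha_fwd_bhsd.py --tune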