# example_mha_fwd_bhsd.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
    """Build the autotuning search space for the attention kernel.

    Each candidate list below holds the values to try for one tuning knob;
    the search space is their Cartesian product.

    Returns:
        A list of dicts, one per combination, each with the keys
        ``block_M``, ``block_N``, ``num_stages`` and ``threads`` (in that order).
    """
    search_space = {
        'block_M': [128],
        'block_N': [128],
        'num_stages': [2],
        'threads': [256],
    }
    knob_names = list(search_space)
    return [
        dict(zip(knob_names, combo))
        for combo in itertools.product(*search_space.values())
    ]


def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
    """Build a TileLang FlashAttention forward kernel for BHSD-layout tensors.

    Q, K, V and Output are declared with shape ``[batch, heads, seq_len, dim]``
    and dtype float16; scores, softmax statistics and the output accumulator
    use float32.

    Args:
        batch, heads, seq_len, dim: Problem sizes baked into the kernel.
        is_causal: If True, query row ``q`` only attends to key rows ``<= q``.
        tune: If True, wrap the kernel builder with ``autotune``/``jit`` over
            ``get_configs()`` and return the result of calling it. Otherwise
            return a factory ``kernel(block_M, block_N, num_stages, threads)``
            that builds the ``T.prim_func`` for one configuration.
    """
    # The softmax is evaluated with exp2, so log2(e) is folded into the
    # 1/sqrt(dim) scale here (exp(x) == exp2(x * log2(e))).
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, heads, seq_len, dim]
    dtype = "float16"
    accum_dtype = "float"

    def kernel_func(block_M, block_N, num_stages, threads):

        @T.macro
        def MMA0(
            K: T.Buffer(shape, dtype),
            Q_shared: T.Buffer([block_M, dim], dtype),
            K_shared: T.Buffer([block_N, dim], dtype),
            acc_s: T.Buffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
            # Load key block k, initialise the score tile, then accumulate
            # Q_shared @ K_shared^T into it.
            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
            if is_causal:
                # Visible (query >= key) entries start at 0, masked ones at -inf,
                # so the gemm result is already causally masked.
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                 -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
                V: T.Buffer(shape, dtype),
                # V_shared holds block_N value rows (it is allocated as
                # [block_N, dim] in `main`); declaring it [block_M, dim] was
                # wrong whenever block_M != block_N.
                V_shared: T.Buffer([block_N, dim], dtype),
                acc_s_cast: T.Buffer([block_M, block_N], dtype),
                acc_o: T.Buffer([block_M, dim], accum_dtype),
                k: T.int32,
                by: T.int32,
                bz: T.int32,
        ):
            # Load value block k and accumulate P @ V into acc_o.
            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
                acc_s: T.Buffer([block_M, block_N], accum_dtype),
                acc_s_cast: T.Buffer([block_M, block_N], dtype),
                scores_max: T.Buffer([block_M], accum_dtype),
                scores_max_prev: T.Buffer([block_M], accum_dtype),
                scores_scale: T.Buffer([block_M], accum_dtype),
                scores_sum: T.Buffer([block_M], accum_dtype),
                logsum: T.Buffer([block_M], accum_dtype),
        ):
            # Online-softmax step for one key block: update the per-row max,
            # exponentiate the tile, and rescale the running row sum.
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            # Keep a *running* max across key blocks. Using only the current
            # block's max lets exp2(prev - cur) overflow, and for a row whose
            # current block is fully masked (causal with block_M > block_N)
            # cur is -inf, which turns into inf/NaN. Block k == 0 always holds
            # key 0, which every query row may attend to, so the running max
            # is finite from the first step on and no Check_inf pass is needed.
            for i in T.Parallel(block_M):
                scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, block_N):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)) This allows the compiler to use the ffma
                # instruction instead of fadd and fmul separately.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
                acc_o: T.Buffer([block_M, dim], accum_dtype),
                scores_scale: T.Buffer([block_M], accum_dtype),
        ):
            # Bring the output accumulator in line with the updated row max.
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def main(
                Q: T.Buffer(shape, dtype),
                K: T.Buffer(shape, dtype),
                V: T.Buffer(shape, dtype),
                Output: T.Buffer(shape, dtype),
        ):
            # Grid: bx = query block, by = head, bz = batch.
            with T.Kernel(
                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                # Under causality, key blocks entirely after the last query row
                # of this block contribute nothing and are skipped.
                loop_range = (
                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                            scores_sum, logsum)
                    Rescale(acc_o, scores_scale)
                    MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                # Final normalisation by the accumulated softmax denominator.
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

        return main

    if tune:

        @autotune(
            configs=get_configs(),
            keys=["block_M", "block_N", "num_stages", "threads"],
            warmup=10,
            rep=10)
        @jit(
            out_idx=[3],
            supply_type=tilelang.TensorSupplyType.Integer,
            ref_prog=None,
            profiler="auto")
        def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel()
    else:

        def kernel(block_M, block_N, num_stages, threads):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel


def ref_program(Q, K, V, is_causal):
    """Reference attention in plain PyTorch, used to check the kernel.

    Args:
        Q, K, V: Tensors in BHSD layout ``[batch, heads, seq_len, dim]``
            (the einsum subscripts below fix this layout).
        is_causal: If True, each query position only attends to key positions
            at or before it.

    Returns:
        ``softmax(Q @ K^T / sqrt(dim)) @ V`` with shape ``[batch, heads, seq_len, dim]``.
    """
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
    # A Python scalar avoids building a CPU tensor (and a half-precision CPU
    # sqrt, unsupported on some PyTorch versions) on every call.
    scores = scores / dim**0.5
    if is_causal:
        # In BHSD layout the sequence axes of `scores` are the last two
        # (query, key); Q.size(1) would be the head count, not seq_len.
        q_len, k_len = scores.shape[-2:]
        mask = torch.tril(
            torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device))
        # The [q_len, k_len] mask broadcasts over batch and heads.
        scores = scores.masked_fill(~mask, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
    return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    batch, heads, seq_len, dim, is_causal = args.batch, args.heads, args.seq_len, args.dim, args.is_causal

    # Two GEMMs (Q @ K^T and P @ V), each 2 * seq_len^2 * dim FLOPs per head.
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        # Roughly half of the score matrix is masked out.
        total_flops *= 0.5

    if not args.tune:
        program = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=args.tune)(
                block_M=128, block_N=128, num_stages=1, threads=128)
        # Bind is_causal under a new name instead of shadowing the
        # module-level ref_program function.
        ref_fn = partial(ref_program, is_causal=is_causal)
        mod, params = tilelang.lower(program)
        mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
        mod.assert_allclose(ref_fn, rtol=0.01, atol=0.01)
        print("All checks pass.")
        # Latencies are printed as ms, so flops / ms * 1e-9 is TFLOP/s.
        latency = mod.do_bench(ref_fn, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = mod.do_bench(mod.func, warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        best_latency, best_config, _ = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=args.tune)
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")