# example_gqa_fwd_bshd.py
# Grouped-query attention (GQA) flash-attention forward example for TileLang,
# with Q/K/V/Output laid out as [batch, seq_len, heads, dim] (BSHD).
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


class FlashAttentionTuneSpace:
    """Search space that ``get_configs`` enumerates for the autotuner.

    Attributes:
        block_sizes: candidate tile sizes, used for both ``block_M`` and
            ``block_N``.
        thread_options: candidate threads-per-block values; ``get_configs``
            requires each to be a multiple of 32.
        num_stages_range: candidate software-pipelining depths.
        max_shared_mem: shared-memory budget (bytes) that ``get_configs``
            compares its per-config estimate against.
        warp_alignment: each warp's share of ``block_M`` and ``block_N`` must
            be a multiple of this value.
        dim: head dimension used in the shared-memory estimate.
        dtype_bytes: bytes per element used in the shared-memory estimate.
    """

    def __init__(
        self,
        block_sizes=(64, 128, 256),
        thread_options=(128, 256, 512),
        num_stages_range=(2, 3),
        max_shared_mem=100 * 1024,
        warp_alignment=16,
        dim=128,
        dtype_bytes=2,
    ):
        # Launch-shape candidates.
        self.block_sizes, self.thread_options, self.num_stages_range = (
            block_sizes, thread_options, num_stages_range)
        # Feasibility-filter parameters.
        self.max_shared_mem, self.warp_alignment = max_shared_mem, warp_alignment
        # Problem parameters feeding the shared-memory estimate.
        self.dim, self.dtype_bytes = dim, dtype_bytes


def get_configs(user_config=None):
    """Enumerate candidate kernel configurations for the autotuner.

    Args:
        user_config: Optional search-space object exposing ``block_sizes``,
            ``thread_options``, ``num_stages_range``, ``max_shared_mem``,
            ``warp_alignment``, ``dim`` and ``dtype_bytes``.  When falsy, a
            default ``FlashAttentionTuneSpace()`` is used.

    Returns:
        A list of dicts with keys ``block_M``, ``block_N``, ``num_stages`` and
        ``threads``.  Block pairs follow ``itertools.product`` order, then
        thread options, then pipeline depths.

    Raises:
        AssertionError: if a thread option is not a multiple of 32.
    """
    space = user_config or FlashAttentionTuneSpace()
    candidates = []

    for block_M, block_N in itertools.product(space.block_sizes, repeat=2):
        for threads in space.thread_options:
            assert threads % 32 == 0
            n_warps = threads // 32
            rows_per_warp = block_M // n_warps
            cols_per_warp = block_N // n_warps
            # Each warp's slice of the tile must be warp_alignment-aligned.
            aligned = (rows_per_warp % space.warp_alignment == 0
                       and cols_per_warp % space.warp_alignment == 0)
            if not aligned:
                continue

            # Rough estimate: two block_M x dim tiles plus two block_N x dim
            # tiles (pipeline-stage copies are not counted).
            smem_bytes = 2 * space.dtype_bytes * space.dim * (block_M + block_N)
            if smem_bytes > space.max_shared_mem:
                continue

            candidates.extend({
                "block_M": block_M,
                "block_N": block_N,
                "num_stages": stages,
                "threads": threads,
            } for stages in space.num_stages_range)

    return candidates


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
              heads,
              seq_len,
              dim,
              is_causal,
              groups=1,
              block_M=64,
              block_N=64,
              num_stages=0,
              threads=128):
    """Build the TileLang GQA flash-attention forward kernel (BSHD layout).

    Q and Output have shape ``[batch, seq_len, heads, dim]``.  K and V have
    shape ``[batch, seq_len, heads // groups, dim]``.  Query head ``h`` reads
    KV head ``h // groups``.  ``out_idx=[3]`` marks ``Output`` (the 4th kernel
    argument) as the output.

    Args:
        batch, heads, seq_len, dim: problem sizes (``heads`` counts query heads).
        is_causal: if true, query position q only attends to key positions <= q.
        groups: query heads per KV head.  ``heads`` is expected to be a
            multiple of it, since ``head_kv = heads // groups``.
        block_M: query rows processed per thread block.
        block_N: key/value rows processed per pipelined loop iteration.
        num_stages: pipelining depth passed to ``T.Pipelined``.
        threads: threads per block.

    Returns:
        The ``T.prim_func`` kernel ``main(Q, K, V, Output)``.
    """
    # 1/sqrt(dim) softmax scale, pre-multiplied by log2(e) so exp2 can be used.
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.macro
    def MMA0(
        K: T.Tensor(kv_shape, dtype),
        Q_shared: T.SharedBuffer([block_M, dim], dtype),
        K_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
        k: T.int32,
        bx: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        # Load key tile `k` of the KV head shared by query head `by`.
        T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
        # Pre-fill acc_s with an additive mask (0 = visible, -inf = masked).
        # The GEMM below accumulates Q @ K^T on top of it.
        if is_causal:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                             -T.infinity(acc_s.dtype))
        else:
            # Mask key positions past seq_len in a partial last tile so they
            # cannot enter the softmax when seq_len % block_N != 0.
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
                                             -T.infinity(acc_s.dtype), 0)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def MMA1(
        V: T.Tensor(kv_shape, dtype),
        V_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        k: T.int32,
        by: T.int32,
        bz: T.int32,
    ):
        # Load value tile `k` (block_N x dim, matching the V_shared allocation)
        # and accumulate P @ V into acc_o.
        T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
            scores_max: T.FragmentBuffer([block_M], accum_dtype),
            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        # Online-softmax step for one KV tile: update the running row max,
        # compute the factor that rescales earlier partial results,
        # exponentiate this tile's scores and fold their row sums into logsum.
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # Keep a *running* max. Without this, a row whose current tile is
        # fully masked (causal, block_M > block_N) gets max = -inf, and the
        # rescale below becomes exp2(+inf) and exp2(-inf + inf) = NaN.
        # Because every row sees key 0 in the first tile, the running max is
        # finite from then on. FlashAttention3's "Check_inf" (mapping a -inf
        # max to 0) is therefore not needed here.
        for i in T.Parallel(block_M):
            scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
        for i in T.Parallel(block_M):
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
            # exp2(x * scale - max * scale) == exp((x - max) / sqrt(dim)).
            # This form lets the compiler use a single ffma instead of a
            # separate fadd and fmul.
            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
        T.reduce_sum(acc_s, scores_sum, dim=1)
        for i in T.Parallel(block_M):
            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        T.copy(acc_s, acc_s_cast)

    @T.macro
    def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        # Rescale the partial output to the updated running max.
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            Output: T.Tensor(q_shape, dtype),
    ):
        # Grid: (query tiles, query heads, batch).
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Causal: KV tiles starting past this query tile's last row are
            # fully masked, so stop at ceildiv((bx + 1) * block_M, block_N).
            loop_range = (
                T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
                        logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
            # Final softmax normalization.
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

    return main


def ref_program(Q, K, V, is_causal, groups=1):
    """Plain-PyTorch reference for the GQA attention kernel.

    Args:
        Q: queries, ``[B, T, HQ, D]``.
        K: keys, ``[B, T, HKV, D]`` with ``HQ == HKV * groups``.
        V: values, ``[B, T, HKV, D]`` with ``HQ == HKV * groups``.
        is_causal: if true, query position q only attends to key positions <= q.
        groups: number of query heads sharing each KV head.

    Returns:
        Attention output, ``[B, T, HQ, D]``.

    Raises:
        AssertionError: if the head counts of K or V don't match ``HQ / groups``.
    """
    q_heads = Q.size(2)
    assert q_heads == K.size(2) * groups, \
        f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert q_heads == V.size(2) * groups, \
        f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

    head_dim = Q.size(-1)
    # Expand every KV head to the `groups` consecutive query heads sharing it.
    k_full = K.repeat_interleave(groups, dim=2)
    v_full = V.repeat_interleave(groups, dim=2)

    logits = torch.einsum('bqhd,bkhd->bhqk', Q, k_full)
    logits = logits / torch.sqrt(torch.tensor(head_dim, dtype=logits.dtype))
    if is_causal:
        n = Q.size(1)
        keep = torch.tril(torch.ones(n, n, device=logits.device))[None, None]
        logits = logits.masked_fill(keep == 0, float('-inf'))

    probs = F.softmax(logits, dim=-1)
    return torch.einsum('bhqk,bkhd->bqhd', probs, v_full)


def main(batch: int = 1,
         heads: int = 64,
         seq_len: int = 4096,
         dim: int = 128,
         is_causal: bool = False,
         groups: int = 16,
         tune: bool = False):
    """Check and benchmark the GQA forward kernel, or autotune it.

    Args:
        batch, heads, seq_len, dim: problem sizes (``heads`` counts query heads).
        is_causal: use the causal (lower-triangular) mask.
        groups: query heads per KV head.
        tune: run the autotuner instead of the fixed configuration.

    With ``tune=False``, a fixed configuration (64x64 tiles, 2 stages, 128
    threads) is compiled and compared against ``ref_program`` with
    ``rtol=atol=0.01``. Then both implementations are timed.

    With ``tune=True``, the autotuned kernel's best latency, TFLOPS and config
    are printed, along with the reference latency.
    """
    # Two GEMMs (Q @ K^T and P @ V), counting 2 flops per multiply-add.
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        # About half of the score matrix is masked out.
        total_flops *= 0.5

    if not tune:
        kernel = flashattn(
            batch,
            heads,
            seq_len,
            dim,
            is_causal,
            groups=groups,
            block_M=64,
            block_N=64,
            num_stages=2,
            threads=128)
        ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
        print("All checks pass.")
        # Latencies are reported in ms, so flops / ms * 1e-9 == TFLOPS.
        latency = profiler.do_bench(ref_program_processed, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        # Forward `groups` so tuning targets the same GQA problem as the
        # fixed-config path. flashattn otherwise defaults to groups=1 (MHA).
        kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups)
        best_latency = kernel.latency
        best_config = kernel.config
        ref_latency = kernel.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
        print(f"Ref latency: {ref_latency}")


if __name__ == "__main__":
    # Command-line entry point; each flag maps directly onto a `main` keyword.
    cli = argparse.ArgumentParser()
    cli.add_argument('--batch', type=int, default=1, help='batch size')
    cli.add_argument('--heads', type=int, default=64, help='heads')
    cli.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    cli.add_argument('--dim', type=int, default=128, help='dim')
    cli.add_argument('--is_causal', action='store_true', help='causal')
    cli.add_argument('--tune', action='store_true', help='tune configs')
    cli.add_argument('--groups', type=int, default=16, help='groups')
    opts = cli.parse_args()
    main(
        batch=opts.batch,
        heads=opts.heads,
        seq_len=opts.seq_len,
        dim=opts.dim,
        is_causal=opts.is_causal,
        groups=opts.groups,
        tune=opts.tune,
    )