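"""FlashAttention-2 forward pass for AMD GPUs (ROCm/HIP), written in TileLang.

The kernel is autotuned over tile sizes, thread counts, and copy widths, then
checked against a PyTorch reference implementation and benchmarked.
"""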
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.tileop.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial


# Custom supply function to ensure tensors are created on GPU
def supply_tensors_gpu(params):
    """Supply function that creates tensors on GPU for ROCm/HIP."""
    tensors = []
    for param in params:
        if hasattr(param, "shape") and hasattr(param, "dtype"):
            # Force creation on GPU device
            shape = [int(s) for s in param.shape]
            tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
            tensors.append(tensor)
        else:
            tensors.append(param)
    return tensors


def ref_program(Q, K, V, is_causal, groups=1):
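    """PyTorch reference attention; K/V heads are repeated to match grouped-query layouts."""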
    assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
    return output


def get_configs():
    """Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
    block_M = [32, 64, 128, 256]
    block_N = [32, 64, 128, 256]
    threads = [128, 256, 512]
    num_split_q = [64, 128, 256]
    num_stages = [0, 1]
    enable_rasterization = [True]
    k_pack = [2]
    panel_size = [7, 8]
    qk_coalesced_width = [8]
    v_coalesced_width = [4]

    valid_configs = []

    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
        block_M, block_N, num_split_q, threads, num_stages,
        enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width,
    ):
        valid_configs.append(
            {
                "block_M": m,
                "block_N": n,
                "num_split_q": s,
                "threads": t,
                "num_stages": stages,
                "enable_rasterization": r,
                "k_pack": k,
                "panel_size": p,
                "qk_coalesced_width": qkw,
                "v_coalesced_width": vw,
            }
        )
    return valid_configs


@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    groups,
    block_M: int,
    block_N: int,
    num_split_q: int,
    threads: int,
    num_stages: int,
    enable_rasterization: bool,
    k_pack: int,
    panel_size: int,
    qk_coalesced_width: int,
    v_coalesced_width: int,
):
    scale = (1.0 / dim) ** 0.5  # softmax scaling factor 1/sqrt(dim)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    dtype = "float16"
    accum_dtype = "float"

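    # Vector widths (elements per access) used as `coalesced_width` for the
    # global->shared copies of Q/K and of V.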
    vec_size = qk_coalesced_width
    v_vec_size = v_coalesced_width

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
            T.use_swizzle(panel_size, enable=enable_rasterization)

            bz = byz_combined // heads
            by = byz_combined % heads

            num_q_blocks = T.ceildiv(seq_len, block_M)

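            # FA-2 style work distribution: this program instance starts at query block
            # `b_split` and strides over query blocks in steps of `num_split_q`.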
            bx = T.alloc_var("int32")
            bx = b_split

            with T.While(bx < num_q_blocks):
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                m_i = T.alloc_fragment([block_M], accum_dtype)
                l_i = T.alloc_fragment([block_M], accum_dtype)
                T.fill(acc_o, 0)
                T.fill(m_i, -T.infinity(accum_dtype))
                T.fill(l_i, 0)

                current_bx = bx
                q_block_offset = current_bx * block_M

                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                # Use register fragment for P instead of shared memory to reduce LDS usage
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)

                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                m_prev = T.alloc_fragment([block_M], accum_dtype)
                scale_factor = T.alloc_fragment([block_M], accum_dtype)

                T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)

                loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)

                row_sum = T.alloc_fragment([block_M], accum_dtype)

                for k in T.Pipelined(loop_end_k, num_stages=num_stages):
                    kv_idx = k * block_N

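                    # Grouped-query attention: `groups` query heads share one KV head,
                    # hence the KV head index `by // groups`.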
                    T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
                    T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)

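                    # Causal path: pre-fill the score tile with 0 / -inf so the mask is
                    # already in place when the QK^T GEMM accumulates into acc_s.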
                    if is_causal:
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
                    else:
                        T.clear(acc_s)
                    T.gemm(
                        Q_shared,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        k_pack=k_pack,
                        policy=GemmWarpPolicy.FullRow,
                    )

                    T.copy(m_i, m_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for i in T.Parallel(block_M):
                        m_i[i] = T.max(m_i[i], m_prev[i])

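                    # Online-softmax update: rescale the running denominator l_i and the
                    # output accumulator acc_o by exp(m_prev - m_new) before adding this block.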
                    for i in T.Parallel(block_M):
                        sf = T.exp(m_prev[i] * scale - m_i[i] * scale)
                        l_i[i] *= sf
                        scale_factor[i] = sf

                    for i, j in T.Parallel(block_M, dim):
                        acc_o[i, j] *= scale_factor[i]

                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.exp(acc_s[i, j] * scale - m_i[i] * scale)

                    T.reduce_sum(acc_s, row_sum, dim=1)
                    for i in T.Parallel(block_M):
                        l_i[i] += row_sum[i]

                    # Cast acc_s (accum_dtype) to dtype in registers and directly GEMM with V
                    T.copy(acc_s, acc_s_cast)

                    T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow)

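                # Final normalization by the softmax denominator; guard against
                # all-masked rows to avoid division by zero.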
                l_inv = T.alloc_fragment([block_M], accum_dtype)
                for i in T.Parallel(block_M):
                    safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
                    l_inv[i] = 1.0 / safe_l

                for i, j in T.Parallel(block_M, dim):
                    Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]

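                # Advance this program instance to its next query block.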
                bx = current_bx + num_split_q

    return main


def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
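    """Autotune the kernel, verify it against the PyTorch reference, and benchmark both."""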
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim  # one seq_len x seq_len x dim GEMM per head
    total_flops = 2 * flops_per_matmul  # two GEMMs per block: QK^T and PV
    if is_causal:
        total_flops *= 0.5  # causal masking roughly halves the useful work

    print("Starting autotuning for FlashAttention-V2...")
    kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups=groups)
    print(f"Autotuning finished. Best Configuration: {kernel.config}")

    ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    print("Verifying correctness...")
    profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
    print("All checks pass.")

    latency = profiler.do_bench(ref_program_processed, warmup=100)
    print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")

    latency = profiler.do_bench(warmup=100)
    print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=1, help="batch size")
    parser.add_argument("--heads", type=int, default=8, help="heads")
    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
    parser.add_argument("--dim", type=int, default=128, help="dim")
    parser.add_argument("--is_causal", action="store_true", help="causal")
    parser.add_argument("--groups", type=int, default=1, help="groups")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)