# example_mha_fwd_varlen.py
# ruff: noqa
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
import argparse

import torch
from einops import rearrange, repeat
10
from varlen_utils import generate_random_padding_mask, generate_qkv
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48


def attention_ref(
        q,
        k,
        v,
        query_padding_mask=None,
        key_padding_mask=None,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite window size
        upcast=True,
):
    """Eager PyTorch attention forward used as the numerical reference.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
            nheads must be a multiple of nheads_k; each key/value head is
            shared by nheads // nheads_k consecutive query heads.
        query_padding_mask: (batch_size, seqlen_q) bool, True = valid token.
            Rows of padded queries are zeroed in both returned tensors.
        key_padding_mask: (batch_size, seqlen_k) bool, True = valid token.
            Padded keys receive a score of -inf before the softmax.
        causal: if True, query position t only attends to key positions
            s <= t (mask aligned to the top-left corner of the padded
            score matrix, i.e. to the start of every sequence).
        window_size: (int, int). Kept for interface compatibility only;
            apart from the causal case above no sliding-window masking is
            applied by this reference.
        upcast: whether to cast all inputs to fp32, do all computation in
            fp32, then cast the results back to the dtype of ``q``.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), the softmax
            probabilities.
    """
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()

    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    dim = q.shape[-1]
    scale = (1.0 / dim)**0.5  # standard 1 / sqrt(head_dim) softmax scaling

    # Expand the key/value heads so every query head has a partner
    # (same ordering as the einops pattern "b s h d -> b s (h g) d").
    k = k.repeat_interleave(q.shape[2] // k.shape[2], dim=2)
    v = v.repeat_interleave(q.shape[2] // v.shape[2], dim=2)

    scores = torch.einsum("bthd,bshd->bhts", q, k)
    if key_padding_mask is not None:
        scores.masked_fill_((~key_padding_mask)[:, None, None, :], float("-inf"))
    if causal:
        # Strict upper triangle == keys that lie in the future of the query.
        future_keys = torch.ones(
            seqlen_q, seqlen_k, dtype=torch.bool, device=scores.device).triu(diagonal=1)
        scores.masked_fill_(future_keys, float("-inf"))
    scores = scores * scale
    attention = torch.softmax(scores, dim=-1).to(v.dtype)

    # Zero the rows that belong to padded queries so neither the attention
    # matrix nor the output carries values for positions that do not exist.
    if query_padding_mask is not None:
        attention = attention.masked_fill((~query_padding_mask)[:, None, :, None], 0.0)
    output = torch.einsum("bhts,bshd->bthd", attention, v)
    if query_padding_mask is not None:
        output.masked_fill_((~query_padding_mask)[:, :, None, None], 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


70
71
72
73
@tilelang.jit(
    out_idx=[6], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch_size,
              UQ,
              UKV,
              heads,
              dim,
              is_causal,
              block_M=64,
              block_N=64,
              num_stages=0,
              threads=32):
    """Build the variable-length multi-head attention forward kernel.

    All sequences of the batch are packed back to back along the first
    tensor dimension; ``cu_seqlens_q`` / ``cu_seqlens_k`` hold the
    ``batch_size + 1`` boundaries of every sequence inside that packing.

    Arguments:
        batch_size: number of sequences in the batch.
        UQ: first dimension of the packed query/output tensors.
        UKV: first dimension of the packed key/value tensors.
        heads: number of attention heads (shared by Q, K and V here).
        dim: head dimension.
        is_causal: if True, query i of a sequence only attends to keys
            j <= i of the same sequence (both indices counted from the
            start of the sequence).
            NOTE(review): references that align the causal mask differently
            will disagree when the q and k lengths of a sequence differ --
            confirm before comparing causal results.
        block_M: query rows handled by one kernel block.
        block_N: key/value rows consumed per loop iteration.
        num_stages: forwarded to ``T.Pipelined``.
        threads: forwarded to ``T.Kernel``.

    Returns:
        The ``T.prim_func``. ``Output_unpad`` is its parameter number 6, the
        index given in ``out_idx``; the caller in this file passes the six
        other arguments and receives the output tensor as the result.
    """
    # Softmax scale 1/sqrt(dim) folded with log2(e) so exp2 can be used.
    scale = (1.0 / dim)**0.5 * 1.44269504
    q_shape = [UQ, heads, dim]
    k_shape = [UKV, heads, dim]
    v_shape = [UKV, heads, dim]
    o_shape = [UQ, heads, dim]

    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
            Q_unpad: T.Tensor(q_shape, dtype),
            K_unpad: T.Tensor(k_shape, dtype),
            V_unpad: T.Tensor(v_shape, dtype),
            cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
            cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
            max_seqlen_q: T.int32,
            Output_unpad: T.Tensor(o_shape, dtype),
    ):
        # One block per (query tile, head, sequence).
        with T.Kernel(
                T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
                threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype, "shared")
            K_shared = T.alloc_shared([block_N, dim], dtype, "shared")
            V_shared = T.alloc_shared([block_N, dim], dtype, "shared")
            O_shared = T.alloc_shared([block_M, dim], dtype, "shared")
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            batch_idx = bz
            head_idx = by

            # Boundaries of this sequence inside the packed tensors.
            q_start_idx = cu_seqlens_q[batch_idx]
            k_start_idx = cu_seqlens_k[batch_idx]
            v_start_idx = cu_seqlens_k[batch_idx]
            q_end_idx = cu_seqlens_q[batch_idx + 1]
            k_end_idx = cu_seqlens_k[batch_idx + 1]
            v_end_idx = cu_seqlens_k[batch_idx + 1]

            q_current_seqlen = q_end_idx - q_start_idx
            k_current_seqlen = k_end_idx - k_start_idx
            v_current_seqlen = v_end_idx - v_start_idx

            # Load the query tile; rows past the end of the sequence are zero.
            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i < q_current_seqlen:
                    Q_shared[i, d] = Q_unpad[q_start_idx + bx * block_M + i, head_idx, d]
                else:
                    Q_shared[i, d] = 0

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(k_current_seqlen, block_N)

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                # Q * K
                for i, d in T.Parallel(block_N, dim):
                    if k * block_N + i < k_current_seqlen:
                        K_shared[i, d] = K_unpad[k_start_idx + k * block_N + i, head_idx, d]
                    else:
                        K_shared[i, d] = 0

                # Pre-fill the accumulator with the additive mask: -inf for
                # every (query, key) pair that must not contribute, 0 otherwise.
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        # Masked when the key lies in the future of the query
                        # OR either index is past the end of its sequence.
                        acc_s[i, j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
                                                     (bx * block_M + i >= q_current_seqlen or
                                                      k * block_N + j >= k_current_seqlen),
                                                     -T.infinity(acc_s.dtype), 0)
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
                                                      k * block_N + j >= k_current_seqlen),
                                                     -T.infinity(acc_s.dtype), 0)

                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                # Softmax
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                # Keep a running maximum across key tiles. Without this a
                # tile that is fully masked for a row (possible with causal
                # masking) would leave scores_max at -inf and the exp2 calls
                # below would evaluate (-inf) - (-inf) = NaN for that row.
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, block_N):
                    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                    # max * log_2(e)) This allows the compiler to use the ffma
                    # instruction instead of fadd and fmul separately.
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)

                # Rescale
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]

                # V * softmax(Q * K)
                # Loaded with T.Parallel, same pattern as the K tile above.
                for i, d in T.Parallel(block_N, dim):
                    if k * block_N + i < v_current_seqlen:
                        V_shared[i, d] = V_unpad[v_start_idx + k * block_N + i, head_idx, d]
                    else:
                        V_shared[i, d] = 0

                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # Final softmax normalisation.
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)

            # Only rows that belong to the sequence are written back, so
            # whatever was accumulated for out-of-range query rows is dropped.
            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i < q_current_seqlen:
                    Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]

    return main
209
210


211
def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
    """Run the varlen attention kernel once and check it against two references.

    Random fp16 q/k/v tensors of shape (batch, seq_len, heads, dim) are
    created, padding masks are drawn with ``generate_random_padding_mask``
    and the packed ("unpadded") tensors plus cumulative sequence lengths are
    obtained from ``generate_qkv``. The kernel output is padded back with
    ``output_pad_fn`` and compared (rtol = atol = 1e-2) with

    * ``attention_ref`` evaluated on the padded tensors, and
    * ``flash_attn.flash_attn_varlen_func`` evaluated on the packed tensors.

    Requires a CUDA device and the ``flash_attn`` package.

    Arguments:
        batch: number of sequences.
        heads: number of attention heads.
        seq_len: padded sequence length.
        dim: head dimension.

    Raises:
        AssertionError: from ``torch.testing.assert_close`` when the kernel
            output differs from a reference beyond the tolerance.
    """
    tilelang.testing.set_random_seed(0)

    causal = False

    dtype = torch.float16
    device = torch.device("cuda")

    q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
    k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
    v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)

    query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(
        q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)

    UQ = q_unpad.shape[0]  # first dimension of the packed query tensor
    UKV = k_unpad.shape[0]  # first dimension of the packed key/value tensors

    kernel = flashattn(batch, UQ, UKV, heads, dim, causal)

    out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
    out = output_pad_fn(out_unpad)

    out_ref, _ = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
    )
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)

    # Imported lazily so the module can be imported without flash_attn.
    import flash_attn

    fla_out_unpad = flash_attn.flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
    )
    fla_out = output_pad_fn(fla_out_unpad)
    torch.testing.assert_close(out, fla_out, rtol=1e-2, atol=1e-2)

    print("All checks passed.✅")
284
285
286
287


if __name__ == "__main__":
    # Command-line entry point: every option is an integer forwarded to main().
    cli = argparse.ArgumentParser()
    option_table = (
        ("--batch", 8, "batch size"),
        ("--heads", 64, "heads"),
        ("--seq_len", 2048, "sequence length"),
        ("--dim", 128, "dim"),
    )
    for flag, default_value, description in option_table:
        cli.add_argument(flag, type=int, default=default_value, help=description)

    opts = cli.parse_args()
    main(opts.batch, opts.heads, opts.seq_len, opts.dim)