import math
import torch

import tilelang
import tilelang.language as T
import torch.nn.functional as F


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
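    """Select the top-k key blocks per query-block row, optionally keep the last two
    rows fully dense, and apply a block-level causal (lower-triangular) constraint."""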
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
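    """Keep key blocks whose score exceeds `threshold`, optionally keep the last two
    rows fully dense, and apply a block-level causal (lower-triangular) constraint."""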
    dense_mask = x > threshold
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


@tilelang.jit(
    out_idx=[4],
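    # out_idx=[4]: argument 4 (Output) is allocated by the runtime and returned by the kernel call.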
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal):
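    # Tile configuration: each program handles one (block_M x dim) query tile and streams
    # (block_N x dim) key/value tiles, using `threads` threads per block.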
    block_M = 64
    block_N = 64
    num_stages = 0
    threads = 128
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, heads, seq_kv, dim]
    block_mask_shape = [batch, heads, downsample_len, downsample_len]

    dtype = T.float16
    accum_dtype = T.float32
    block_mask_dtype = T.int8

    def kernel_func(block_M, block_N, num_stages, threads):
        @T.macro
        def Softmax(
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
            scores_max: T.FragmentBuffer([block_M], accum_dtype),
            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
            logsum: T.FragmentBuffer([block_M], accum_dtype),
        ):
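            # Online-softmax update: merge the running row max with this tile's max,
            # rescale the running sum, and exponentiate the current score tile.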
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            for i in T.Parallel(block_M):
                scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
            # For causal softmax, we need to set scores_max to 0 if it is -inf.
            # This step is called Check_inf in the FlashAttention3 code, and it only needs to be
            # done in the first ceil_div(kBlockM, kBlockN) steps.
            # for i in T.Parallel(block_M):
            #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, block_N):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)) This allows the compiler to use the ffma
                # instruction instead of fadd and fmul separately.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
        ):
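            # Scale the partial output accumulator by the correction factor from the updated row max.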
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def main(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
            Output: T.Tensor(q_shape, dtype),
        ):
            with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)
                block_mask = T.alloc_local([downsample_len], block_mask_dtype)

                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                for vj in T.serial(downsample_len):
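                    # Stage this query tile's row of the block mask into a per-thread local buffer.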
                    block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                loop_range = T.ceildiv(seq_kv, block_N)

                for k in T.Pipelined(loop_range, num_stages=num_stages):
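                    # Skip K/V tiles whose block-mask entry is zero; only selected blocks are computed.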
                    if block_mask[k] != 0:
                        T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
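                        # Pre-fill acc_s with an additive causal bias (0 = keep, -inf = mask);
                        # the GEMM below accumulates Q @ K^T on top of this bias.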
                        if is_causal:
                            past_len = seq_kv - seq_q
                            for i, j in T.Parallel(block_M, block_N):
                                acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                        else:
                            T.clear(acc_s)
                        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                        Rescale(acc_o, scores_scale)
                        T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
                        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

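                # Epilogue: normalize by the softmax denominator and store the output tile.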
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]

                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

        return main

    return kernel_func(block_M, block_N, num_stages, threads)


def test_topk_sparse_attention():
    # Config
    BATCH, N_HEADS, SEQ_LEN, D_HEAD = 4, 2, 256, 64
    TOPK = 2  # Keep the top 2 key blocks per query-block row
    BLOCK = 64
    torch.manual_seed(0)

    # Create inputs
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)

    sm_scale = 1.0 / (D_HEAD**0.5)

    # Create sparse mask (downsampled to block level)
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16)
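    # Force the first column high so that block 0 is always among the top-k picks.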
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

    # Run tilelang kernel
    kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))

    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

    # PyTorch reference implementation
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

    print("ref_output", ref_output)
    print("tilelang_output", tilelang_output)

    # Verify accuracy
    assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), "TileLang output doesn't match reference"
    print("Pass topk sparse attention test with qlen == klen")


def test_topk_sparse_attention_qlen_lt_klen():
    # Config
    BATCH, N_HEADS = 1, 1
    Q_LEN, K_LEN, D_HEAD = 128, 256, 64  # qlen < klen; here, past_len = 256 - 128 = 128.
    TOPK = 1
    BLOCK = 64  # block size used in downsampling
    torch.manual_seed(0)

    # Create inputs.
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    sm_scale = 1.0 / (D_HEAD**0.5)

    downsample_factor = BLOCK
    downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16)
    # Force the first column to be high so that the first block is always selected.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

    kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
    print(kernel.get_kernel_source())
    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))

    past_len = K_LEN - Q_LEN

    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale

    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]

    effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)

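    # Queries occupy absolute positions [past_len, K_LEN); compare them with key positions for causality.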
    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)

    final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)

    attn = attn.masked_fill(~final_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

    print("ref_output", ref_output)
    print("tilelang_output", tilelang_output)

    # Verify accuracy.
    torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)

    print("Pass topk sparse attention test with qlen < klen")


def main():
    test_topk_sparse_attention()
    test_topk_sparse_attention_qlen_lt_klen()


if __name__ == "__main__":
    main()