# example_tilelang_nsa_fwd.py
# ruff: noqa
import torch
3
from reference import naive_nsa
4
5
6
7
8
9
10
import tilelang
from tilelang import language as T
import tilelang.testing

# Fix the random seed through tilelang's testing helper so example runs are
# repeatable (which RNGs it seeds is defined in tilelang.testing, not here).
tilelang.testing.set_random_seed(0)


@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16):
    """Build the native sparse attention (NSA) forward kernel over selected key blocks.

    For every query position ``t`` and kv-head, the kernel attends over the
    ``selected_blocks`` key/value blocks listed in ``BlockIndices``. An entry
    times ``block_size`` gives the block's start position. Attention uses an
    online (running max / running sum) softmax.

    Args:
        batch: batch size.
        heads: number of query heads.
        seq_len: sequence length of Q, K and V.
        dim: head dimension; must be <= 128 so a single key tile covers it.
        is_causal: if True, keys after the query position inside a selected
            block are masked out.
        scale: softmax scale; defaults to ``dim ** -0.5``. It is multiplied by
            log2(e) internally because the kernel uses ``exp2``.
        block_size: number of tokens per selected key/value block.
        groups: query heads per kv-head (there are ``heads // groups`` kv-heads).
        selected_blocks: number of block indices per (batch, position, kv-head).

    Returns:
        The jit-compiled kernel. It is called as ``kernel(Q, K, V, BlockIndices)``
        and returns ``Output`` (``out_idx=[-1]``). Shapes and dtypes:
        Q/Output ``[batch, seq_len, heads, dim]`` float16;
        K/V ``[batch, seq_len, heads // groups, dim]`` float16;
        BlockIndices ``[batch, seq_len, heads // groups, selected_blocks]`` int32.

    Raises:
        AssertionError: if ``dim`` exceeds the 128-wide key tile.
    """
    if scale is None:
        scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    else:
        scale = scale * 1.44269504  # log2(e)

    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
    block_indices_dtype = T.int32
    dtype = T.float16
    accum_dtype = T.float32

    block_S = block_size
    # The key/value tile width is capped at 128, so any dim > 128 needs NK > 1,
    # which this kernel does not support.
    block_T = min(128, tilelang.math.next_power_of_2(dim))

    NK = tilelang.cdiv(dim, block_T)
    NV = tilelang.cdiv(dim, block_T)
    assert NK == 1, f"The key dimension can not be larger than {block_T}, got dim={dim}"

    S = selected_blocks
    G = groups
    BS = block_S
    BK = BV = block_T
    num_stages = 2
    threads = 32

    @T.prim_func
    def native_sparse_attention(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
        Output: T.Tensor(q_shape, dtype),
    ):
        # One program per (query position, value tile, batch * kv-head).
        with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([G, BK], dtype)
            K_shared = T.alloc_shared([BS, BK], dtype)
            V_shared = T.alloc_shared([BS, BV], dtype)
            O_shared = T.alloc_shared([G, BV], dtype)

            acc_s = T.alloc_fragment([G, BS], accum_dtype)
            acc_s_cast = T.alloc_fragment([G, BS], dtype)
            acc_o = T.alloc_fragment([G, BV], accum_dtype)
            scores_max = T.alloc_fragment([G], accum_dtype)
            scores_max_prev = T.alloc_fragment([G], accum_dtype)
            scores_scale = T.alloc_fragment([G], accum_dtype)
            scores_sum = T.alloc_fragment([G], accum_dtype)
            logsum = T.alloc_fragment([G], accum_dtype)

            i_t, i_v, i_bh = bx, by, bz
            i_b, i_h = i_bh // head_kv, i_bh % head_kv

            NS = S
            # The G query heads that share kv-head i_h are processed together.
            T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared)

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            for i in T.Pipelined(NS, num_stages=num_stages):
                # Start position of the i-th selected key block.
                i_s = BlockIndices[i_b, i_t, i_h, i] * BS
                # Blocks that start after the query position, or have a negative
                # index, are skipped; this is what drops out-of-range padding
                # entries. NOTE(review): the skip also applies when is_causal is
                # False, so future blocks are never attended — confirm intended.
                if i_s <= i_t and i_s >= 0:
                    # [BS, BK]
                    T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)

                    # The GEMM below accumulates into acc_s, so seed it with the
                    # causal mask (0 / -inf) or with zeros.
                    if is_causal:
                        for g, j in T.Parallel(G, BS):
                            acc_s[g, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype))
                    else:
                        T.clear(acc_s)

                    # acc_s += Q_shared @ K_shared^T  -> [G, BS]
                    T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                    # Online softmax in base 2 (scale already includes log2(e)).
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=True)
                    for g in T.Parallel(G):
                        scores_scale[g] = T.exp2(scores_max_prev[g] * scale - scores_max[g] * scale)
                    for g, j in T.Parallel(G, BS):
                        acc_s[g, j] = T.exp2(acc_s[g, j] * scale - scores_max[g] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for g in T.Parallel(G):
                        logsum[g] = logsum[g] * scores_scale[g] + scores_sum[g]
                    T.copy(acc_s, acc_s_cast)

                    # Rescale the running output to the new row maxima.
                    for g, j in T.Parallel(G, BV):
                        acc_o[g, j] *= scores_scale[g]

                    # acc_o += softmax(Q * K) @ V
                    T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
                    T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # Normalize by the accumulated softmax denominator.
            for g, j in T.Parallel(G, BV):
                acc_o[g, j] /= logsum[g]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV])

    return native_sparse_attention


def main():
    """Compile the NSA forward kernel, run it on random data, and compare with ``naive_nsa``.

    Requires a CUDA device. For each (batch, position, kv-head), up to ``S`` block
    indices are drawn from ``range(max(1, t // block_size))``. Unused slots keep
    the sentinel ``SEQ_LEN``; the kernel skips it because ``SEQ_LEN * block_size``
    lies beyond every query position.
    """
    B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1

    kernel = native_sparse_attention(
        batch=B,
        heads=HQ,
        seq_len=SEQ_LEN,
        dim=D,
        is_causal=True,
        block_size=block_size,
        groups=HQ // H,
        selected_blocks=S,
        scale=scale,
    )
    print(kernel.get_kernel_source())

    torch.random.manual_seed(0)
    Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
    K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
    V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
    # All-ones gates so the reference reduces to the selected-block branch.
    g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
    g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)

    block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda")
    for b in range(B):
        for t in range(SEQ_LEN):
            for h in range(H):
                i_i = torch.randperm(max(1, (t // block_size)))[:S]
                block_indices[b, t, h, : len(i_i)] = i_i
    # Count the non-sentinel slots in one vectorized pass instead of one
    # host-synchronizing .item() call per (b, t, h).
    block_counts = (block_indices != SEQ_LEN).sum(-1)
    # Sorting keeps the sentinel SEQ_LEN entries at the end of each row.
    block_indices = block_indices.sort(-1)[0]

    out = kernel(Q, K, V, block_indices.to(torch.int32))

    ref = naive_nsa(
        q=Q,
        k=K,
        v=V,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        scale=scale,
    )

    print("out", out)
    print("ref", ref)
    torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    main()