# ruff: noqa
import torch
from reference import naive_nsa_simple_inference
import tilelang
from tilelang import language as T
import tilelang.testing

tilelang.testing.set_random_seed(42)


# TODO(lei): Workaround: since `threads` is not divisible by the warp group size,
# automatic warp specialization may have bugs, so it is disabled below.
@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def native_sparse_attention(
    batch,
    heads,
    seq_len,  # Length of K/V sequences (context window size)
    dim,  # Embedding dimension per head
    scale=None,
    block_size=64,  # Tile size for attention computation
    groups=1,  # Grouped query attention (GQA) groups
    selected_blocks=16,  # Number of blocks to select per attention head
):
    if scale is None:
        scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
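    # The factor 1.44269504 (= log2(e)) folds the natural-exponential softmax into
    # base-2 form: exp(x / sqrt(dim)) == exp2(x * scale), letting the kernel use
    # the cheaper exp2 intrinsic while still computing a standard softmax.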
    head_kv = heads // groups
    # Modified shapes for inference (q has seq_len=1)
    q_shape = [batch, 1, heads, dim]  # decode: the query holds a single position
    kv_shape = [batch, seq_len, head_kv, dim]
    block_indices_shape = [batch, 1, head_kv, selected_blocks]  # one index set per decode step
    block_indices_dtype = "int32"
    dtype = "float16"
    accum_dtype = "float"
    block_S = block_size
    block_T = min(128, tilelang.math.next_power_of_2(dim))

    NK = tilelang.cdiv(dim, block_T)
    NV = tilelang.cdiv(dim, block_T)
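    # NK tiles the head dimension of Q/K and must stay a single tile; NV tiles the
    # head dimension of V and is mapped onto the second grid axis of T.Kernel below.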
    assert NK == 1, "The head dimension (dim) cannot exceed 128 in this kernel"

    S = selected_blocks
    G = groups
    BS = block_S
    BK = BV = block_T
    num_stages = 0
    threads = 32
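    # Only 32 threads (one warp) per block are launched; see the TODO above the
    # decorator for why warp specialization is disabled. num_stages=0 is forwarded
    # to T.Pipelined, so the block loop runs without software pipelining.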

    @T.prim_func
    def native_sparse_attention(
        Q: T.Tensor(q_shape, dtype),  # [batch, 1, heads, dim]
        K: T.Tensor(kv_shape, dtype),  # [batch, seq_len, head_kv, dim]
        V: T.Tensor(kv_shape, dtype),  # Same shape as K
        BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),  # Selected block indices
        Output: T.Tensor(q_shape, dtype),  # Output attention tensor
    ):
        with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz):
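            # Grid mapping: bx is unused (first grid dim is 1), by selects the
            # value-dimension tile (i_v), and bz enumerates (batch, kv-head)
            # pairs, decoded into (i_b, i_h) below.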
            # Shared memory allocations for tile storage
            Q_shared = T.alloc_shared([G, BK], dtype)  # Current query block
            K_shared = T.alloc_shared([BS, BK], dtype)  # Current key block
            V_shared = T.alloc_shared([BS, BV], dtype)  # Current value block
            O_shared = T.alloc_shared([G, BV], dtype)  # Output accumulator

            # Attention computation buffers
            acc_s = T.alloc_fragment([G, BS], accum_dtype)  # QK^T scores
            acc_s_cast = T.alloc_fragment([G, BS], dtype)  # Casted scores for softmax
            acc_o = T.alloc_fragment([G, BV], accum_dtype)  # Output accumulator
            scores_max = T.alloc_fragment([G], accum_dtype)
            scores_max_prev = T.alloc_fragment([G], accum_dtype)
            scores_scale = T.alloc_fragment([G], accum_dtype)
            scores_sum = T.alloc_fragment([G], accum_dtype)
            logsum = T.alloc_fragment([G], accum_dtype)

            i_v, i_bh = by, bz
            i_b, i_h = i_bh // head_kv, i_bh % head_kv

            NS = S
            # Copy Q for the single position
            T.copy(Q[i_b, 0, i_h * G : (i_h + 1) * G, :], Q_shared)  # decode: the query position index is 0

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Main attention computation loop over selected blocks
            for i in T.Pipelined(NS, num_stages=num_stages):
                i_s = BlockIndices[i_b, 0, i_h, i] * BS  # Get block offset
                if i_s >= 0:  # Skip invalid/padding blocks
                    # Load current key block to shared memory
                    T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)

                    # Compute QK^T attention scores
                    T.clear(acc_s)
                    T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                    # Online softmax with numerical stability
                    # 1. Compute max for scaling
                    # 2. Compute exponentials and sum
                    # 3. Maintain running logsum for normalization
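                    # In formulas, per query row g (m = scores_max, alpha = scores_scale,
                    # l = logsum, o = acc_o):
                    #   m_new = max_j s_gj
                    #   alpha = exp2((m_old - m_new) * scale)
                    #   p_gj  = exp2((s_gj - m_new) * scale)
                    #   l_new = alpha * l_old + sum_j p_gj
                    #   o_new = alpha * o_old + p @ V      (o is rescaled just below)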
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=True)

                    for i in T.Parallel(G):
                        scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                    for i, j in T.Parallel(G, BS):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for i in T.Parallel(G):
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                    T.copy(acc_s, acc_s_cast)

                    # Rescale the running output accumulator to the new max, then
                    # accumulate the attention-weighted values for this block
                    for i, j in T.Parallel(G, BV):
                        acc_o[i, j] *= scores_scale[i]
                    T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
                    T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # Final normalization and output
            for i, j in T.Parallel(G, BV):
                acc_o[i, j] /= logsum[i]  # Normalize by logsum
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[i_b, 0, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV])  # decode: write query position 0

    return native_sparse_attention
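

# Illustrative sketch (not used by main()): a plain-PyTorch rendering of what the
# kernel above computes for one decode step. The function name and its exact
# argument handling are assumptions for readability; main() still validates the
# kernel against `naive_nsa_simple_inference` from reference.py. Like the kernel,
# it attends only to blocks whose index is non-negative and ignores block_counts.
@torch.no_grad()
def torch_nsa_decode_sketch(q, k, v, block_indices, block_size):
    # q: [B, 1, HQ, D], k/v: [B, T, H, D], block_indices: [B, 1, H, S] (integer)
    B, _, HQ, D = q.shape
    H = k.shape[2]
    G = HQ // H  # query heads per KV head (GQA group size)
    scale = D ** -0.5
    out = torch.zeros_like(q)
    for b in range(B):
        for h in range(H):
            idx = block_indices[b, 0, h]
            idx = idx[idx >= 0]  # negative entries mark invalid/padding blocks
            if idx.numel() == 0:
                continue
            # Token positions covered by the selected blocks.
            pos = (idx[:, None] * block_size
                   + torch.arange(block_size, device=q.device)).reshape(-1)
            kk = k[b, pos, h].float()                # [S * BS, D]
            vv = v[b, pos, h].float()                # [S * BS, D]
            qq = q[b, 0, h * G:(h + 1) * G].float()  # [G, D]
            p = torch.softmax(qq @ kk.T * scale, dim=-1)
            out[b, 0, h * G:(h + 1) * G] = (p @ vv).to(q.dtype)
    return out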


def main():
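    # Small smoke-test configuration: batch 2, 64-token context, 1 KV head shared
    # by 16 query heads (GQA group of 16), head dim 16, 1 selected block of size 32.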
    B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16
    groups = HQ // H
    SEQ_LEN_Q = 1
    kernel = native_sparse_attention(
        batch=B,
        heads=HQ,
        seq_len=SEQ_LEN,
        dim=D,
        block_size=block_size,
        groups=HQ // H,
        selected_blocks=S,
    )

    Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
    K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
    V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)

    mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda")  # unused in this inference-only example
    DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda")  # unused (no backward pass here)

    # Sentinel-filled; every slot is overwritten in the loop below for this configuration.
    block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda")
    for b in range(B):
        for t in range(SEQ_LEN_Q):
            for h in range(H):
                i_i = torch.randperm(max(1, (t // block_size)))[:S]
                block_indices[b, t, h, : len(i_i)] = i_i
    block_indices = block_indices.sort(-1)[0]
    block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda")
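    # block_counts is consumed only by the reference implementation below; the
    # TileLang kernel loops over all S selected blocks and relies on negative
    # indices to mark invalid entries.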

    out = kernel(Q, K, V, block_indices.to(torch.int32))
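    # Only Q, K, V and the int32 block indices are passed: out_idx=[-1] in the
    # tilelang.jit decorator marks the last parameter (Output) as the kernel
    # output, so the JIT wrapper allocates it and returns it as `out`.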

    ref = naive_nsa_simple_inference(
        q=Q,
        k=K,
        v=V,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
    )
    torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    main()