# example_tilelang_nsa_fwd_varlen.py
1
2
3
# ruff: noqa
import torch
from typing import Optional, Union
4
5
from packaging.version import parse

6
7
8
9
import tilelang
from tilelang import language as T
import tilelang.testing

10
import fla
11

12
13
14
15
if parse(fla.__version__) < parse("0.2.1"):
    from fla.ops.common.utils import prepare_token_indices
else:
    from fla.ops.utils import prepare_token_indices
16
17
18
19
from reference import naive_nsa
from einops import rearrange


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    }
)
def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16):
    """Build the TileLang forward kernel for NSA selected-block attention on packed varlen input.

    All sequences are packed along one token axis of length ``c_seq_len``.
    ``Offsets`` holds the sequence boundaries. ``TokenIndices[t] = (sequence id,
    position inside that sequence)`` holds the per-token lookup.

    Args:
        batch: number of packed sequences; sizes ``Offsets`` as ``batch + 1`` entries.
        heads: number of query heads (HQ).
        c_seq_len: total number of tokens across all packed sequences.
        dim: head dimension shared by Q, K, V and the output; must be <= 128.
        is_causal: if True, keys whose in-sequence position is after the query
            token are masked out.
        scale: softmax scale already multiplied by log2(e), because the kernel
            exponentiates with ``exp2``. Defaults to ``dim ** -0.5 * log2(e)``.
        block_size: number of key/value tokens in one selected block.
        groups: query heads per key/value head (GQA ratio).
        selected_blocks: capacity S of each token's block-index list.

    Returns:
        The ``T.prim_func`` (handled by ``tilelang.jit``) taking
        ``(Q, K, V, O_slc, BlockIndices, BlockCounts, Offsets, TokenIndices)``
        and writing the attention output into ``O_slc``.
    """
    if scale is None:
        scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)

    head_kv = heads // groups
    q_shape = [c_seq_len, heads, dim]
    kv_shape = [c_seq_len, head_kv, dim]
    o_slc_shape = [c_seq_len, heads, dim]
    block_indices_shape = [c_seq_len, head_kv, selected_blocks]
    block_counts_shape = [c_seq_len, head_kv]
    offsets_shape = [batch + 1]
    token_indices_shape = [c_seq_len, 2]

    block_indices_dtype = T.int32
    block_counts_dtype = T.int32
    offsets_dtype = T.int32
    token_indices_dtype = T.int32
    dtype = T.float16
    accum_dtype = T.float32

    block_S = block_size
    block_T = min(128, tilelang.math.next_power_of_2(dim))

    NK = tilelang.cdiv(dim, block_T)
    NV = tilelang.cdiv(dim, block_T)
    # block_T is capped at 128, so a single K tile can only cover dim <= 128.
    assert NK == 1, "The key dimension can not be larger than 128"

    G = groups
    BS = block_S
    BK = BV = block_T
    num_stages = 0
    threads = 32

    @T.prim_func
    def native_sparse_attention_varlen(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        O_slc: T.Tensor(o_slc_shape, dtype),
        BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
        BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype),
        Offsets: T.Tensor(offsets_shape, offsets_dtype),
        TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype),
    ):
        # One program per (packed token, V tile, KV head). The owning sequence is
        # looked up through TokenIndices, so the grid needs no batch axis; a
        # batch axis would only recompute and rewrite identical outputs.
        with T.Kernel(c_seq_len, NV, head_kv, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([G, BK], dtype)
            K_shared = T.alloc_shared([BS, BK], dtype)
            V_shared = T.alloc_shared([BS, BV], dtype)
            O_shared = T.alloc_shared([G, BV], dtype)

            acc_s = T.alloc_fragment([G, BS], accum_dtype)
            acc_s_cast = T.alloc_fragment([G, BS], dtype)
            acc_o = T.alloc_fragment([G, BV], accum_dtype)
            scores_max = T.alloc_fragment([G], accum_dtype)
            scores_max_prev = T.alloc_fragment([G], accum_dtype)
            scores_scale = T.alloc_fragment([G], accum_dtype)
            scores_sum = T.alloc_fragment([G], accum_dtype)
            logsum = T.alloc_fragment([G], accum_dtype)

            i_c, i_v, i_h = bx, by, bz

            # (sequence id, position inside that sequence) of packed token i_c.
            i_n, i_t = TokenIndices[i_c, 0], TokenIndices[i_c, 1]
            bos = Offsets[i_n]

            # BlockIndices / BlockCounts are laid out per packed token, so they
            # must be indexed with the global position bos + i_t.
            NS = BlockCounts[bos + i_t, i_h]
            T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared)

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            for i in T.Pipelined(NS, num_stages=num_stages):
                # Start of the i-th selected block, in sequence-local positions.
                i_s = BlockIndices[bos + i_t, i_h, i] * BS
                # Skip out-of-range / padded entries and blocks entirely in the future.
                if i_s <= i_t and i_s >= 0:
                    # [BS, BK]
                    # Lei: may have some padding issues
                    # we should learn from mha varlen templates to handle this
                    T.copy(K[bos + i_s : bos + i_s + BS, i_h, :BK], K_shared)

                    if is_causal:
                        for g, s in T.Parallel(G, BS):
                            acc_s[g, s] = T.if_then_else(i_t >= (i_s + s), 0, -T.infinity(acc_s.dtype))
                    else:
                        T.clear(acc_s)

                    T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                    # Online softmax: carry the running row max so rescale factors stay <= 1.
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=True)
                    for g in T.Parallel(G):
                        scores_max[g] = T.max(scores_max[g], scores_max_prev[g])
                    for g in T.Parallel(G):
                        scores_scale[g] = T.exp2(scores_max_prev[g] * scale - scores_max[g] * scale)
                    for g, s in T.Parallel(G, BS):
                        acc_s[g, s] = T.exp2(acc_s[g, s] * scale - scores_max[g] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for g in T.Parallel(G):
                        logsum[g] = logsum[g] * scores_scale[g] + scores_sum[g]
                    T.copy(acc_s, acc_s_cast)

                    # Rescale the partial output to the new running max.
                    for g, d in T.Parallel(G, BV):
                        acc_o[g, d] *= scores_scale[g]

                    # V * softmax(Q * K)
                    T.copy(V[bos + i_s : bos + i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
                    T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # NOTE(review): if no block is accepted (NS == 0), logsum stays 0 and this divides by zero.
            for g, d in T.Parallel(G, BV):
                acc_o[g, d] /= logsum[g]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, O_slc[bos + i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV])

    return native_sparse_attention_varlen


def parallel_nsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.LongTensor,
    block_counts: Union[torch.LongTensor, int],
    block_size: int,
    window_size: int,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Run the TileLang selected-block attention kernel on packed varlen inputs.

    Args:
        q: queries ``[1, T, HQ, K]``; all sequences are packed along ``T``, so the
            leading batch dimension must be 1 for the ``view`` calls below.
        k: keys ``[1, T, H, K]``.
        v: values ``[1, T, H, V]``. The kernel uses one head dimension for K and V,
            so ``V`` must equal ``K``.
        block_indices: selected block ids per token and KV head, ``[1, T, H, S]``.
        block_counts: number of valid entries of ``block_indices`` per token, as a
            ``[1, T, H]`` tensor. An ``int`` applies the same count to every token,
            and ``None`` means ``S``.
        block_size: tokens per selected block.
        window_size: sliding-window size; not used by this kernel.
        scale: natural-log softmax scale (e.g. ``K ** -0.5``). It is converted to
            the kernel's ``exp2`` domain. ``None`` keeps the kernel default.
        offsets: cumulative sequence lengths ``[N + 1]``. Required.
        token_indices: per-token ``(sequence id, position in sequence)`` pairs,
            ``[T, 2]``. Required.

    Returns:
        ``o_slc`` of shape ``[1, T, HQ, V]`` with dtype ``v.dtype``.

    Raises:
        ValueError: if ``offsets`` or ``token_indices`` is not provided.
    """
    if offsets is None or token_indices is None:
        raise ValueError("parallel_nsa_fwd requires `offsets` and `token_indices` (variable-length input)")

    B, C_SEQ_LEN, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]

    batch = len(offsets) - 1
    HQ = q.shape[2]
    G = HQ // H

    # The kernel exponentiates with exp2, so a natural-log scale has to be
    # multiplied by log2(e); None lets the kernel pick dim ** -0.5 * log2(e).
    kernel_scale = None if scale is None else scale * 1.44269504

    kernel = native_sparse_attention_varlen(
        batch=batch,
        heads=HQ,
        c_seq_len=C_SEQ_LEN,
        dim=K,
        is_causal=True,
        scale=kernel_scale,
        block_size=block_size,
        groups=G,
        selected_blocks=S,
    )

    # The kernel always reads a per-token count tensor.
    if block_counts is None or isinstance(block_counts, int):
        count = S if block_counts is None else block_counts
        block_counts_i32 = torch.full((C_SEQ_LEN, H), count, dtype=torch.int32, device=q.device)
    else:
        block_counts_i32 = block_counts.to(torch.int32).view(C_SEQ_LEN, H)

    o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
    kernel(
        q.view(C_SEQ_LEN, HQ, K),
        k.view(C_SEQ_LEN, H, K),
        v.view(C_SEQ_LEN, H, V),
        o_slc.view(C_SEQ_LEN, HQ, V),
        block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
        block_counts_i32,
        offsets.to(torch.int32),
        token_indices.to(torch.int32),
    )
    return o_slc


@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
    """Autograd entry point for the TileLang NSA selected-block forward pass.

    Only ``forward`` is defined here; no ``backward`` is implemented.
    """

    @staticmethod
    def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
        """Compute the selected-block attention output.

        Args:
            ctx: autograd context; the input dtype is recorded on it.
            q, k, v: query / key / value tensors forwarded to ``parallel_nsa_fwd``.
            block_indices: selected block ids per token and KV head.
            block_counts: per-token number of valid entries in ``block_indices``.
            block_size: tokens per selected block.
            window_size: sliding-window size (forwarded, unused by the kernel).
            scale: natural-log softmax scale.
            offsets: cumulative sequence lengths, or ``None``.

        Returns:
            The attention output cast to ``q.dtype``.
        """
        ctx.dtype = q.dtype

        # 2-d sequence indices denoting the offsets of tokens in each sequence
        # for example, if the passed `offsets` is [0, 2, 6],
        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        token_indices = prepare_token_indices(offsets) if offsets is not None else None

        o_slc = parallel_nsa_fwd(
            q=q,
            k=k,
            v=v,
            block_indices=block_indices,
            block_counts=block_counts,
            block_size=block_size,
            window_size=window_size,
            scale=scale,
            offsets=offsets,
            token_indices=token_indices,
        )
        return o_slc.to(q.dtype)


def parallel_nsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g_slc: torch.Tensor,
    g_swa: torch.Tensor,
    block_indices: torch.LongTensor,
    block_counts: Optional[Union[torch.LongTensor, int]] = None,
    block_size: int = 64,
    window_size: int = 0,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a multiple of 16.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g_slc (torch.Tensor):
            Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
        g_swa (torch.Tensor):
            Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
            Currently unused, because sliding-window attention is not implemented.
        block_indices (torch.LongTensor):
            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
        block_counts (Union[torch.LongTensor, int]):
            Number of selected blocks for each token.
            If a tensor is provided, its shape is `[B, T, H]` if `head_first=False` else `[B, H, T]`,
            and each token may use a different number of blocks.
            If an int is provided, every token uses that many blocks, clamped to `S`.
            If not provided, it will default to `S`. Default: `None`.
        block_size (int):
            Selected block size. Default: 64.
        window_size (int):
            Sliding window size. Only 0 is supported. Default: 0.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API. When given, `B` must be 1.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.

    Raises:
        NotImplementedError: if `window_size > 0`.
    """
    # Only the selected-block branch exists; fail before launching any kernel.
    if window_size > 0:
        raise NotImplementedError("Window size is not supported yet")

    if scale is None:
        scale = k.shape[-1] ** -0.5
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
    if head_first:
        q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
        g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
        if isinstance(block_counts, torch.Tensor):
            block_counts = rearrange(block_counts, "b h t -> b t h")
    assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"

    if block_counts is None:
        block_counts = block_indices.shape[-1]
    if isinstance(block_counts, int):
        block_indices = block_indices[:, :, :, :block_counts]
        # Materialise the uniform count as a per-token tensor (the kernel wrapper
        # needs a tensor). After slicing, the last dim is min(block_counts, S).
        block_counts = torch.full(
            block_indices.shape[:-1],
            block_indices.shape[-1],
            dtype=torch.long,
            device=block_indices.device,
        )

    o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
    o = o_slc * g_slc.unsqueeze(-1)
    if head_first:
        o = rearrange(o, "b t h d -> b h t d")
    return o


if __name__ == "__main__":
    N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16
    torch.manual_seed(42)

    # Randomly split the packed sequence into N segments: the N - 1 interior
    # boundaries are drawn without replacement from [16, C_SEQ_LEN). The
    # permutation length must match the candidate count to stay in bounds.
    split_candidates = torch.arange(16, C_SEQ_LEN)
    offsets = (
        torch.cat(
            [
                torch.tensor([0], dtype=torch.long),
                split_candidates[torch.randperm(len(split_candidates))[: N - 1]],
                torch.tensor([C_SEQ_LEN], dtype=torch.long),
            ],
            0,
        )
        .cuda()
        .sort()[0]
    )

    # seq-first required for inputs with variable lengths
    def _permuted_ramp(num_heads):
        """Return a [1, C_SEQ_LEN, num_heads, D] leaf tensor whose per-token value is a shuffled linspace(0, 1)."""
        perm = torch.randperm(C_SEQ_LEN, device="cuda")
        return (
            torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm]
            .view(1, C_SEQ_LEN, 1, 1)
            .expand(1, C_SEQ_LEN, num_heads, D)
            .clone()
            .requires_grad_(True)
        )

    q = _permuted_ramp(HQ)
    k = _permuted_ramp(H)
    v = _permuted_ramp(H)
    g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
    g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)

    # Pick up to S causal blocks per token. Unused slots keep the out-of-range
    # value C_SEQ_LEN, which the kernel skips.
    token_indices = prepare_token_indices(offsets).tolist()
    block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device="cuda")
    for i in range(C_SEQ_LEN):
        _, t = token_indices[i]
        for h in range(H):
            i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S]
            block_indices[0, i, h, : len(i_i)] = i_i
    block_indices = block_indices.sort(-1)[0]
    block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device="cuda")

    ref = naive_nsa(
        q=q,
        k=k,
        v=v,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        cu_seqlens=offsets,
    )

    tri = parallel_nsa(
        q=q,
        k=k,
        v=v,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        cu_seqlens=offsets,
    )

    print("tri", tri)
    print("ref", ref)

    torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)