example_tilelang_nsa_bwd.py 30 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse

import torch
import triton

import fla
if parse(fla.__version__) < parse("0.2.1"):
    from fla.ops.common.utils import prepare_token_indices
else:
    from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
import tilelang


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
def tilelang_kernel_fwd(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    scale=None,
    block_size=64,
    groups=1,
    selected_blocks=16,
):
    """Build the TileLang forward kernel for NSA selected-block attention.

    The returned prim_func takes ``(Q, K, V, BlockIndices, O_slc, LSE_slc)`` with
    the shapes declared below:

    * ``Q`` / ``O_slc``: ``[batch, seq_len, heads, dim]``
    * ``K`` / ``V``: ``[batch, seq_len, heads // groups, dim]``
    * ``BlockIndices``: ``[batch, seq_len, heads // groups, selected_blocks]`` (int32)
    * ``LSE_slc``: ``[batch, seq_len, heads]`` (float32)

    One program handles one query token and one KV head, i.e. the ``groups``
    query heads that share it. For each listed block index ``b`` with
    ``0 <= b * block_size <= t`` the program attends to keys
    ``[b * block_size, (b + 1) * block_size)`` with an online softmax. Other
    indices (e.g. padding sentinels) are skipped. ``LSE_slc`` receives
    ``log2(sum_j exp(sm_scale * q . k_j))``, i.e. the log-normalizer in base-2
    units.

    Args:
        is_causal: when true, keys after the query token inside a selected
            block are masked out.
        scale: softmax scale; defaults to ``1 / sqrt(dim)``. It is multiplied by
            ``log2(e)`` internally so the kernel can use ``exp2``.
        block_size: number of keys per selectable block.
        groups: number of query heads per KV head.
        selected_blocks: number of block-index slots per token.
    """

    from tilelang import language as T

    if scale is None:
        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    else:
        scale = scale * 1.44269504  # log2(e)

    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    o_slc_shape = [batch, seq_len, heads, dim]
    lse_slc_shape = [batch, seq_len, heads]
    block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
    block_indices_dtype = "int32"
    dtype = "float16"
    accum_dtype = "float"
    block_S = block_size
    block_T = min(128, tilelang.math.next_power_of_2(dim))

    NK = tilelang.cdiv(dim, block_T)
    NV = tilelang.cdiv(dim, block_T)
    # block_T is capped at 128, so a single key tile only covers dim <= 128.
    assert NK == 1, f"The key dimension can not be larger than 128, got {dim}"

    S = selected_blocks
    G = groups
    BS = block_S
    BK = BV = block_T
    num_stages = 0
    threads = 32

    @T.prim_func
    def native_sparse_attention(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
            O_slc: T.Tensor(o_slc_shape, dtype),
            LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
    ):
        with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([G, BK], dtype)
            K_shared = T.alloc_shared([BS, BK], dtype)
            V_shared = T.alloc_shared([BS, BV], dtype)
            O_shared = T.alloc_shared([G, BV], dtype)

            acc_s = T.alloc_fragment([G, BS], accum_dtype)
            acc_s_cast = T.alloc_fragment([G, BS], dtype)
            acc_o = T.alloc_fragment([G, BV], accum_dtype)
            scores_max = T.alloc_fragment([G], accum_dtype)
            scores_max_prev = T.alloc_fragment([G], accum_dtype)
            scores_scale = T.alloc_fragment([G], accum_dtype)
            scores_sum = T.alloc_fragment([G], accum_dtype)
            logsum = T.alloc_fragment([G], accum_dtype)

            i_t, i_v, i_bh = bx, by, bz
            i_b, i_h = i_bh // head_kv, i_bh % head_kv

            NS = S
            # All G query heads sharing KV head i_h are processed together.
            T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared)

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            for i in T.Pipelined(NS, num_stages=num_stages):
                i_s = BlockIndices[i_b, i_t, i_h, i] * BS
                # Skip negative / sentinel indices and blocks starting after the token.
                if i_s <= i_t and i_s >= 0:
                    # [BS, BK]
                    T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared)

                    if is_causal:
                        for g, j in T.Parallel(G, BS):
                            acc_s[g, j] = T.if_then_else(i_t >= (i_s + j), 0,
                                                         -T.infinity(acc_s.dtype))
                    else:
                        T.clear(acc_s)

                    T.gemm(
                        Q_shared,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow,
                    )

                    # Online softmax. Keep a running row max so the rescale factor
                    # exp2((m_prev - m_new) * scale) never exceeds 1 (no overflow
                    # when a later block has smaller scores than an earlier one).
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=True)
                    for g in T.Parallel(G):
                        scores_max[g] = T.max(scores_max[g], scores_max_prev[g])
                    for g in T.Parallel(G):
                        scores_scale[g] = T.exp2(scores_max_prev[g] * scale - scores_max[g] * scale)
                    for g, j in T.Parallel(G, BS):
                        acc_s[g, j] = T.exp2(acc_s[g, j] * scale - scores_max[g] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for g in T.Parallel(G):
                        logsum[g] = logsum[g] * scores_scale[g] + scores_sum[g]
                    T.copy(acc_s, acc_s_cast)

                    # Rescale the running output to the new reference max.
                    for g, j in T.Parallel(G, BV):
                        acc_o[g, j] *= scores_scale[g]

                    # V * softmax(Q * K)
                    T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
                    T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # NOTE(review): if no index passed the check above, logsum stays 0 and
            # this division yields inf/NaN for that token.
            for g, j in T.Parallel(G, BV):
                acc_o[g, j] /= logsum[g]
            T.copy(acc_o, O_shared)
            T.copy(
                O_shared,
                O_slc[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV],
            )
            # Base-2 log-normalizer: log2(sum exp2(s*scale - m*scale)) + m*scale.
            for g in T.Parallel(G):
                logsum[g] = T.log2(logsum[g]) + scores_max[g] * scale
            T.copy(logsum, LSE_slc[i_b, i_t, i_h * G:(i_h + 1) * G])

    return native_sparse_attention


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def tilelang_kernel_bwd_dkv(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    scale=None,
    block_size=64,
    groups=1,
    selected_blocks=16,
    dtype="float16",
    accum_dtype="float",
):
    """Build a TileLang kernel computing dK and dV for NSA selected-block attention.

    The returned prim_func takes
    ``(Q, K, V, LSE_slc, Delta_slc, DO_slc, DK, DV, BlockMask)`` with the shapes
    declared below. It writes ``DV`` and ``DK``; ``DK`` has a leading
    ``NV`` axis (one slice per value tile).

    The grid is ``(NV, NS, batch * heads_kv)``. Each program owns one
    ``block_size``-row key/value block ``i_s`` of one KV head. It walks every
    query token from ``i_s * block_size`` to ``seq_len``, and a token
    contributes only when ``BlockMask[b, token, kv_head, i_s] != 0``.
    Probabilities are recomputed as ``exp2(q . k * scale * log2(e) - LSE)``, so
    ``LSE_slc`` must be in base-2 units. ``Delta_slc`` enters as
    ``dS = P * (dP - Delta) * scale``.

    ``is_causal`` and ``selected_blocks`` are accepted but unused. The causal
    mask inside the key block is applied unconditionally.
    """
    if scale is None:
        sm_scale = (1.0 / dim)**0.5
    else:
        sm_scale = scale

    scale = sm_scale * 1.44269504  # log2(e): the kernel works with exp2

    from tilelang import language as T

    B = batch
    BS = block_size
    G = groups
    V = dim
    K = dim
    BK = tilelang.next_power_of_2(K)
    BV = min(128, tilelang.next_power_of_2(dim))
    NS = tilelang.cdiv(seq_len, BS)
    NV = tilelang.cdiv(V, BV)

    heads_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    k_shape = [batch, seq_len, heads_kv, dim]
    v_shape = [batch, seq_len, heads_kv, dim]
    lse_slc_shape = [batch, seq_len, heads]
    delta_slc_shape = [batch, seq_len, heads]
    do_slc_shape = [batch, seq_len, heads, dim]
    dk_shape = [NV, batch, seq_len, heads_kv, dim]
    dv_shape = [batch, seq_len, heads_kv, dim]

    block_mask_shape = [batch, seq_len, heads_kv, NS]
    num_threads = 32

    @T.prim_func
    def flash_bwd_dkv(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(k_shape, dtype),
            V: T.Tensor(v_shape, dtype),
            LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
            Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
            DO_slc: T.Tensor(do_slc_shape, dtype),
            DK: T.Tensor(dk_shape, dtype),
            DV: T.Tensor(dv_shape, dtype),
            BlockMask: T.Tensor(block_mask_shape, "int32"),
    ):
        # Third grid axis enumerates (batch, kv head) pairs.
        with T.Kernel(NV, NS, B * heads_kv, threads=num_threads) as (i_v, i_s, i_bh):
            K_shared = T.alloc_shared([BS, BK], dtype)
            V_shared = T.alloc_shared([BS, BV], dtype)
            Q_shared = T.alloc_shared([G, BK], dtype)
            qkT = T.alloc_fragment([BS, G], accum_dtype)
            qkT_cast = T.alloc_fragment([BS, G], dtype)
            dsT = T.alloc_fragment([BS, G], accum_dtype)
            dsT_cast = T.alloc_fragment([BS, G], dtype)
            lse_shared = T.alloc_shared([G], accum_dtype)
            delta = T.alloc_shared([G], accum_dtype)

            do = T.alloc_shared([G, BV], dtype)
            dv = T.alloc_fragment([BS, BV], accum_dtype)
            dk = T.alloc_fragment([BS, BK], accum_dtype)

            dv_shared = T.alloc_shared([BS, BV], dtype)
            dk_shared = T.alloc_shared([BS, BK], dtype)

            i_b, i_h = i_bh // heads_kv, i_bh % heads_kv

            T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared)
            T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared)

            # [BS, BK]
            T.clear(dk)
            # [BS, BV]
            T.clear(dv)

            T.annotate_layout({
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })

            # Only tokens at or after the start of this key block can attend to it.
            loop_st = i_s * BS
            loop_ed = seq_len
            for i in T.Pipelined(
                    start=loop_st,
                    stop=loop_ed,
                    num_stages=0,
            ):
                b_m_slc = BlockMask[i_b, i, i_h, i_s]
                if b_m_slc != 0:
                    # [G, BK]
                    T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared)
                    T.clear(qkT)
                    # [BS, BK] @ [G, BK] -> [BS, G]
                    T.gemm(
                        K_shared,
                        Q_shared,
                        qkT,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow,
                    )
                    # [G]
                    T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared)

                    # P^T = exp2(K Q^T * scale - lse)
                    for _i, _j in T.Parallel(BS, G):
                        qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j])

                    # Causal mask: key row _i of this block is visible only if it is <= token i.
                    for _i, _j in T.Parallel(BS, G):
                        qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0)

                    # [G, BV]
                    T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do)
                    T.clear(dsT)
                    # [BS, BV] @ [G, BV] -> [BS, G]
                    T.gemm(
                        V_shared,
                        do,
                        dsT,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow,
                    )
                    T.copy(qkT, qkT_cast)
                    # [BS, G] @ [G, BV] -> [BS, BV]
                    T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                    # [G]
                    T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta)
                    # dS^T = P^T * (dP^T - delta) * sm_scale
                    for _i, _j in T.Parallel(BS, G):
                        dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale

                    # [BS, G] @ [G, BK] -> [BS, BK]
                    T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow)

            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV])
            T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK])

    return flash_bwd_dkv


def make_dq_layout(dQ):
    """Return a TileLang layout for a 4-D ``dQ`` buffer indexed ``(b, row, h, col)``.

    atomicAdd can not be vectorized, so dQ is reordered to match the 8x8 gemm
    fragment. Rows and columns are grouped into tiles of eight, and the last
    two output coordinates place ``(row % 8, col % 8)`` inside that tile.
    """
    from tilelang import language as T

    def _fragment_order(b, row, h, col):
        row_tile, row_in_tile = row // 8, row % 8
        col_tile, col_in_tile = col // 8, col % 8
        return [b, row_tile, h, col_tile, col % 2, 4 * row_in_tile + col_in_tile // 2]

    return T.Layout(dQ.shape, _fragment_order)


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def tilelang_kernel_bwd_dqkv(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    scale=None,
    block_size=64,
    groups=1,
    selected_blocks=16,
    dtype="float16",
    accum_dtype="float",
):
    """Build a fused TileLang kernel computing dQ, dK and dV for NSA selected attention.

    The returned prim_func takes
    ``(Q, K, V, LSE_slc, Delta_slc, DO_slc, DQ, DK, DV, BlockMask)`` with the
    shapes declared below. ``DQ`` and ``DK`` carry a leading ``NV`` axis (one
    slice per value tile).

    The grid is ``(NV, NS, batch * heads_kv)``. Each program owns one
    ``block_size``-row key/value block ``i_s`` of one KV head. It walks query
    tokens from ``i_s * block_size`` to ``seq_len``, and a token contributes only
    when ``BlockMask[b, token, kv_head, i_s] != 0``. dK/dV are accumulated
    locally and stored once. dQ contributions are added to ``DQ`` with atomic
    adds, so ``DQ`` must be zero-initialised by the caller.
    ``LSE_slc`` must be in base-2 units (probabilities are recomputed with
    ``exp2``).

    NOTE(review): ``DQ`` is declared with ``dtype``; confirm callers pass a
    tensor of that dtype (``parallel_nsa_bwd`` switches to float32 when NV > 1).

    ``is_causal`` and ``selected_blocks`` are accepted but unused.
    """
    if scale is None:
        sm_scale = (1.0 / dim)**0.5
    else:
        sm_scale = scale

    scale = sm_scale * 1.44269504  # log2(e): the kernel works with exp2

    from tilelang import language as T

    B = batch
    BS = block_size
    G = groups
    V = dim
    K = dim
    BK = tilelang.next_power_of_2(K)
    BV = min(128, tilelang.next_power_of_2(dim))
    NS = tilelang.cdiv(seq_len, BS)
    NV = tilelang.cdiv(V, BV)

    heads_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    k_shape = [batch, seq_len, heads_kv, dim]
    v_shape = [batch, seq_len, heads_kv, dim]
    lse_slc_shape = [batch, seq_len, heads]
    delta_slc_shape = [batch, seq_len, heads]
    do_slc_shape = [batch, seq_len, heads, dim]
    dq_shape = [NV, batch, seq_len, heads, dim]
    dk_shape = [NV, batch, seq_len, heads_kv, dim]
    dv_shape = [batch, seq_len, heads_kv, dim]

    block_mask_shape = [batch, seq_len, heads_kv, NS]
    num_threads = 32

    @T.prim_func
    def flash_bwd_dqkv(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(k_shape, dtype),
            V: T.Tensor(v_shape, dtype),
            LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
            Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
            DO_slc: T.Tensor(do_slc_shape, dtype),
            DQ: T.Tensor(dq_shape, dtype),
            DK: T.Tensor(dk_shape, dtype),
            DV: T.Tensor(dv_shape, dtype),
            BlockMask: T.Tensor(block_mask_shape, "int32"),
    ):
        # Third grid axis enumerates (batch, kv head) pairs.
        with T.Kernel(NV, NS, B * heads_kv, threads=num_threads) as (i_v, i_s, i_bh):
            K_shared = T.alloc_shared([BS, BK], dtype)
            dsT_shared = T.alloc_shared([BS, G], dtype)
            V_shared = T.alloc_shared([BS, BV], dtype)
            Q_shared = T.alloc_shared([G, BK], dtype)
            qkT = T.alloc_fragment([BS, G], accum_dtype)
            qkT_cast = T.alloc_fragment([BS, G], dtype)
            dsT = T.alloc_fragment([BS, G], accum_dtype)
            dsT_cast = T.alloc_fragment([BS, G], dtype)
            lse_shared = T.alloc_shared([G], accum_dtype)
            delta = T.alloc_shared([G], accum_dtype)

            do = T.alloc_shared([G, BV], dtype)
            dv = T.alloc_fragment([BS, BV], accum_dtype)
            dk = T.alloc_fragment([BS, BK], accum_dtype)
            dq = T.alloc_fragment([G, BK], accum_dtype)

            dv_shared = T.alloc_shared([BS, BV], dtype)
            dk_shared = T.alloc_shared([BS, BK], dtype)

            i_b, i_h = i_bh // heads_kv, i_bh % heads_kv

            T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK], K_shared)
            T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV], V_shared)

            # [BS, BK]
            T.clear(dk)
            # [BS, BV]
            T.clear(dv)

            T.annotate_layout({
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })

            # Only tokens at or after the start of this key block can attend to it.
            loop_st = i_s * BS
            loop_ed = seq_len
            for i in T.Pipelined(
                    start=loop_st,
                    stop=loop_ed,
                    num_stages=0,
            ):
                b_m_slc = BlockMask[i_b, i, i_h, i_s]
                if b_m_slc != 0:
                    # [G, BK]
                    T.copy(Q[i_b, i, i_h * G:(i_h + 1) * G, :BK], Q_shared)
                    T.clear(qkT)
                    # [BS, BK] @ [G, BK] -> [BS, G]
                    T.gemm(
                        K_shared,
                        Q_shared,
                        qkT,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow,
                    )
                    # [G]
                    T.copy(LSE_slc[i_b, i, i_h * G:(i_h + 1) * G], lse_shared)

                    # P^T = exp2(K Q^T * scale - lse)
                    for _i, _j in T.Parallel(BS, G):
                        qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j])

                    # Causal mask: key row _i of this block is visible only if it is <= token i.
                    for _i, _j in T.Parallel(BS, G):
                        qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0)

                    # [G, BV]
                    T.copy(DO_slc[i_b, i, i_h * G:(i_h + 1) * G, :BV], do)
                    T.clear(dsT)
                    # [BS, BV] @ [G, BV] -> [BS, G]
                    T.gemm(
                        V_shared,
                        do,
                        dsT,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow,
                    )
                    T.copy(qkT, qkT_cast)
                    # [BS, G] @ [G, BV] -> [BS, BV]
                    T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                    # [G]
                    T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta)
                    # dS^T = P^T * (dP^T - delta) * sm_scale. Distinct loop names keep
                    # the token index `i` unambiguous for the DQ update below.
                    for _i, _j in T.Parallel(BS, G):
                        dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale

                    # [BS, G] @ [G, BK] -> [BS, BK]
                    T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow)

                    T.copy(dsT_cast, dsT_shared)
                    T.clear(dq)
                    # [BS, G] * [BS, BK] -> [G, BK]
                    T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                    # Several key blocks contribute to the same token's dQ, hence atomics.
                    for _i, _j in T.Parallel(G, BK):
                        T.atomic_add(DQ[i_v, i_b, i, i_h * G + _i, _j], dq[_i, _j])

            T.copy(dv, dv_shared)
            T.copy(dk, dk_shared)
            T.copy(dv_shared, DV[i_b, i_s * BS:(i_s + 1) * BS, i_h, :BV])
            T.copy(dk_shared, DK[i_v, i_b, i_s * BS:(i_s + 1) * BS, i_h, :BK])

    return flash_bwd_dqkv


@tilelang.jit(
    out_idx=[2], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def tilelang_kernel_preprocess(
    batch,
    heads,
    seq_len,
    dim,
    dtype="float16",
    accum_dtype="float",
    blk=32,
):
    """Build the backward pre-pass computing ``Delta = sum(O * dO, axis=-1)``.

    The kernel reads ``O`` and ``dO`` (both ``[batch, seq_len, heads, dim]``).
    Argument index 2 is the JIT output: ``Delta`` of shape
    ``[batch, seq_len, heads]`` in ``accum_dtype``. Each program reduces a
    ``blk``-token tile of one head, walking the feature axis in ``blk``-wide
    chunks.
    """
    from tilelang import language as T

    io_shape = [batch, seq_len, heads, dim]
    delta_shape = [batch, seq_len, heads]

    @T.prim_func
    def flash_bwd_prep(
            O: T.Tensor(io_shape, dtype),  # type: ignore
            dO: T.Tensor(io_shape, dtype),  # type: ignore
            Delta: T.Tensor(delta_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (i_h, i_tile, i_b):
            o_tile = T.alloc_fragment([blk, blk], dtype)
            do_tile = T.alloc_fragment([blk, blk], dtype)
            prod = T.alloc_fragment([blk, blk], accum_dtype)
            row_sum = T.alloc_fragment([blk], accum_dtype)
            T.clear(prod)
            # Accumulate elementwise products chunk by chunk, then reduce each row once.
            for i_d in range(T.ceildiv(dim, blk)):
                T.copy(O[i_b, i_tile * blk:(i_tile + 1) * blk, i_h, i_d * blk:(i_d + 1) * blk],
                       o_tile)
                T.copy(dO[i_b, i_tile * blk:(i_tile + 1) * blk, i_h, i_d * blk:(i_d + 1) * blk],
                       do_tile)
                for r, c in T.Parallel(blk, blk):
                    prod[r, c] += o_tile[r, c] * do_tile[r, c]
            T.reduce_sum(prod, row_sum, 1)
            T.copy(row_sum, Delta[i_b, i_tile * blk:(i_tile + 1) * blk, i_h])

    return flash_bwd_prep


@tilelang.jit(
    out_idx=[2], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def tilelang_kernel_block_mask(
    batch,
    heads,
    seq_len,
    selected_blocks,
    block_size,
    dtype="int32",
    use_block_counts=True,
):
    """Build a kernel turning per-token block selections into a dense block mask.

    Inputs:

    * ``BlockIndices``: ``[batch, seq_len, heads, selected_blocks]`` selected
      key-block ids per token (out-of-range ids, e.g. padding, never match).
    * ``BlockCounts``: ``[batch, seq_len, heads]`` number of valid leading slots
      per token (read only when ``use_block_counts`` is true).

    Output (argument index 2, allocated by the JIT): ``BlockMask`` of shape
    ``[batch, seq_len, heads, cdiv(seq_len, block_size)]``. Entry ``[b, t, h, n]``
    is 1 iff key block ``n`` appears in a valid slot of token ``t`` and starts at
    or before ``t`` (``n * block_size <= t``), else 0.

    The mask is indexed by key-block id (as the backward kernels read it), not
    by selection slot. Every entry is written, so the uninitialised output
    buffer needs no zero-fill.
    """
    from tilelang import language as T

    block_indices_shape = [batch, seq_len, heads, selected_blocks]
    block_counts_shape = [batch, seq_len, heads]
    S = selected_blocks
    BS = block_size
    NS = tilelang.cdiv(seq_len, BS)

    block_mask_shape = [batch, seq_len, heads, NS]
    USE_BLOCK_COUNTS = bool(use_block_counts)

    @T.prim_func
    def flash_bwd_block_mask(
            BlockIndices: T.Tensor(block_indices_shape, dtype),  # type: ignore
            BlockCounts: T.Tensor(block_counts_shape, dtype),  # type: ignore
            BlockMask: T.Tensor(block_mask_shape, dtype),  # type: ignore
    ):
        # One program per (token, batch, head, key block) so each output entry has
        # exactly one writer; the selection slots are OR-reduced serially.
        with T.Kernel(seq_len, batch, heads * NS) as (bx, by, bz):
            i_t, i_b, i_hn = bx, by, bz
            i_h, i_n = i_hn // NS, i_hn % NS
            b_m = T.alloc_local([1], dtype)
            b_m[0] = 0
            for i_s in T.serial(S):
                b_i = BlockIndices[i_b, i_t, i_h, i_s]
                if b_i == i_n and b_i * BS <= i_t:
                    if USE_BLOCK_COUNTS:
                        if i_s < BlockCounts[i_b, i_t, i_h].astype(i_s.dtype):
                            b_m[0] = 1
                    else:
                        b_m[0] = 1
            BlockMask[i_b, i_t, i_h, i_n] = b_m[0]

    return flash_bwd_block_mask


def parallel_nsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o_slc: torch.Tensor,
    lse_slc: torch.Tensor,
    do_slc: torch.Tensor,
    o_swa: torch.Tensor,
    lse_swa: torch.Tensor,
    do_swa: torch.Tensor,
    block_indices: torch.Tensor,
    block_counts: Optional[Union[torch.LongTensor, int]],
    block_size: int = 64,
    window_size: int = 0,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Compute ``(dq, dk, dv)`` for the NSA selected-block branch with TileLang kernels.

    Args:
        q: ``[B, T, HQ, K]`` queries; k: ``[B, T, H, K]``; v: ``[B, T, H, V]``.
        o_slc, lse_slc, do_slc: forward output, its (base-2) LSE and the
            incoming output gradient.
        o_swa, lse_swa, do_swa, offsets, token_indices: accepted for interface
            compatibility; not used here.
        block_indices: ``[B, T, H, S]`` selected key-block ids per token.
        block_counts: number of valid leading slots of ``block_indices`` per
            token. Either a ``[B, T, H]`` tensor, an int applied to every token,
            or ``None`` meaning all ``S`` slots are valid.
        block_size: keys per selectable block.
        window_size: must be 0 (sliding window is not implemented).
        scale: softmax scale; ``None`` lets the kernel default to ``1/sqrt(K)``.

    Returns:
        ``dq`` shaped like ``q``, ``dk`` like ``k`` and ``dv`` like ``v``.
    """
    B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BS = block_size
    BV = min(128, triton.next_power_of_2(V))
    NV = triton.cdiv(V, BV)

    assert window_size == 0, "Window size is not supported yet"
    delta_slc = tilelang_kernel_preprocess(B, HQ, T, K)(o_slc, do_slc)

    # dq is accumulated with atomic adds inside the fused kernel, so it starts at zero.
    dq = torch.zeros(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
    dk = torch.empty(NV, *k.shape, dtype=k.dtype, device=q.device)
    dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)

    # The mask kernel always reads a per-token count tensor; broadcast an int (or
    # None, meaning "all S slots") to that shape. parallel_nsa passes None here.
    if not isinstance(block_counts, torch.Tensor):
        count = S if block_counts is None else int(block_counts)
        block_counts = torch.full((B, T, H), count, dtype=torch.int32, device=q.device)

    block_mask = tilelang_kernel_block_mask(B, H, T, S,
                                            BS)(block_indices.to(torch.int32),
                                                block_counts.to(torch.int32)).to(torch.bool)

    fused_qkv_bwd_kernel = tilelang_kernel_bwd_dqkv(
        batch=B,
        heads=HQ,
        seq_len=T,
        dim=K,
        is_causal=True,
        block_size=BS,
        groups=G,
        selected_blocks=S,
        scale=scale,
    )
    fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv,
                         block_mask.to(torch.int32))

    # Reduce the per-value-tile partials.
    dq = dq.sum(0)
    dk = dk.sum(0)
    return dq, dk, dv


class ParallelNSAFunction(torch.autograd.Function):
    """Autograd bridge for the TileLang NSA selected-attention kernels.

    ``forward`` returns ``(o_slc, lse_slc)``. ``backward`` produces gradients
    for ``q``, ``k`` and ``v`` only; the incoming gradient of ``lse_slc`` is
    not consumed by the backward kernels.
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(
        ctx,
        q,
        k,
        v,
        block_indices,
        block_counts,
        block_size,
        window_size,
        scale,
        offsets,
    ):
        """Run the forward kernel and stash what the backward pass needs.

        Args:
            q: ``[B, SEQLEN, HQ, D]``; k, v: ``[B, SEQLEN, H, D]``. ``HQ // H``
                query heads share each KV head.
            block_indices: ``[B, SEQLEN, H, S]`` selected key-block ids.
            block_counts: per-token valid slot counts (tensor) or ``None``.
            block_size, window_size, scale: forwarded to the kernels.
            offsets: must be ``None`` (variable-length input is not supported).

        Returns:
            ``(o_slc, lse_slc)`` of shapes ``[B, SEQLEN, HQ, D]`` (``q.dtype``)
            and ``[B, SEQLEN, HQ]`` (float32).
        """
        ctx.dtype = q.dtype
        assert offsets is None, "Offsets are not supported yet"
        # 2-d sequence indices denoting the offsets of tokens in each sequence
        # for example, if the passed `offsets` is [0, 2, 6],
        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        token_indices = prepare_token_indices(offsets) if offsets is not None else None

        B, SEQLEN, HQ, D = q.shape
        H = k.shape[2]
        G = HQ // H
        S = block_indices.shape[-1]
        kernel = tilelang_kernel_fwd(
            batch=B,
            heads=HQ,
            seq_len=SEQLEN,
            dim=D,
            is_causal=True,
            scale=scale,
            block_size=block_size,
            groups=G,
            selected_blocks=S,
        )
        o_slc = torch.empty(B, SEQLEN, HQ, D, dtype=v.dtype, device=q.device)
        lse_slc = torch.empty(B, SEQLEN, HQ, dtype=torch.float, device=q.device)
        kernel(q, k, v, block_indices.to(torch.int32), o_slc, lse_slc)

        ctx.save_for_backward(q, k, v, o_slc, lse_slc)
        ctx.block_indices = block_indices
        ctx.block_counts = block_counts
        ctx.offsets = offsets
        ctx.token_indices = token_indices
        ctx.block_size = block_size
        ctx.window_size = window_size
        ctx.scale = scale
        return o_slc.to(q.dtype), lse_slc.to(torch.float)

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do_slc, dlse_slc):
        """Return one gradient per ``forward`` input (``None`` for non-tensors)."""
        q, k, v, o_slc, lse_slc = ctx.saved_tensors
        dq, dk, dv = parallel_nsa_bwd(
            q=q,
            k=k,
            v=v,
            o_slc=o_slc,
            o_swa=None,
            lse_slc=lse_slc,
            lse_swa=None,
            do_slc=do_slc,
            do_swa=dlse_slc,  # not used by parallel_nsa_bwd
            block_indices=ctx.block_indices,
            block_counts=ctx.block_counts,
            block_size=ctx.block_size,
            window_size=ctx.window_size,
            scale=ctx.scale,
            offsets=ctx.offsets,
            token_indices=ctx.token_indices,
        )
        # Order matches forward's inputs: q, k, v, block_indices, block_counts,
        # block_size, window_size, scale, offsets.
        return (
            dq.to(q),
            dk.to(k),
            dv.to(v),
            None,
            None,
            None,
            None,
            None,
            None,
        )


def parallel_nsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g_slc: torch.Tensor,
    g_swa: torch.Tensor,
    block_indices: torch.LongTensor,
    block_counts: Optional[Union[torch.LongTensor, int]] = None,
    block_size: int = 64,
    window_size: int = 0,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, SEQLEN, HQ, K]` if `head_first=False` else `[B, HQ, SEQLEN, K]`.
        k (torch.Tensor):
            keys of shape `[B, SEQLEN, H, K]` if `head_first=False` else `[B, H, SEQLEN, K]`.
            GQA is enforced here: the number of query heads (HQ) must be a multiple of
            `16 * H`, since the kernels process the `HQ // H` query heads of a group together.
        v (torch.Tensor):
            values of shape `[B, SEQLEN, H, V]` if `head_first=False` else `[B, H, SEQLEN, V]`.
        g_slc (torch.Tensor):
            Gate score for selected attention of shape `[B, SEQLEN, HQ]` if `head_first=False` else `[B, HQ, SEQLEN]`.
        g_swa (torch.Tensor):
            Gate score for sliding attention of shape `[B, SEQLEN, HQ]` if `head_first=False` else `[B, HQ, SEQLEN]`.
            Currently unused because sliding-window attention is not supported.
        block_indices (torch.LongTensor):
            Block indices of shape `[B, SEQLEN, H, S]` if `head_first=False` else `[B, H, SEQLEN, S]`.
            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
        block_counts (Union[torch.LongTensor, int]):
            Number of selected blocks for each token.
            If a tensor is provided, with shape `[B, H, SEQLEN]` if `head_first=True` else `[B, SEQLEN, H]`,
            each token can select a different number of blocks.
            If an int is provided, only the first `block_counts` slots of `block_indices` are used.
            If not provided, it will default to `S`. Default: `None`.
        block_size (int):
            Selected block size. Default: 64.
        window_size (int):
            Sliding window size. Only `0` (disabled) is supported; any positive value
            raises `NotImplementedError`. Default: 0.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API. Currently must be `None`: the
            underlying autograd function rejects offsets.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`.

    Raises:
        NotImplementedError: if `window_size > 0`.
    """
    if window_size > 0:
        # The autograd function only returns (o_slc, lse_slc); there is no
        # sliding-window output to blend with g_swa.
        raise NotImplementedError("Sliding window attention is not supported yet")
    if scale is None:
        scale = k.shape[-1]**-0.5
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
    if head_first:
        q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"),
                                     (q, k, v, block_indices))
        g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
        if isinstance(block_counts, torch.Tensor):
            block_counts = rearrange(block_counts, "b h t -> b t h")
    assert (q.shape[2] % (k.shape[2] * 16) == 0), "Group size must be a multiple of 16 in NSA"

    if isinstance(block_counts, int):
        # A uniform count is equivalent to truncating the index slots.
        block_indices = block_indices[:, :, :, :block_counts]
        block_counts = None

    # The second output is the log-sum-exp of the selected branch, not a
    # sliding-window output; it is not needed here.
    o_slc, _lse_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
                                                window_size, scale, cu_seqlens)
    o = o_slc * g_slc.unsqueeze(-1)
    if head_first:
        o = rearrange(o, "b t h d -> b h t d")
    return o


if __name__ == "__main__":
    # Smoke test: compare the TileLang NSA output and gradients against the naive
    # reference. This runs at module scope, so B, T, H, ... and block_counts are
    # module-level globals.
    B, T, H, HQ, D, S, block_size, dtype = 1, 32, 1, 16, 32, 1, 32, torch.float16
    torch.random.manual_seed(0)

    def _random_leaf(shape):
        """Random CUDA tensor in `dtype` that requires grad."""
        return torch.randn(shape, dtype=dtype, device="cuda").requires_grad_(True)

    q = _random_leaf((B, T, HQ, D))
    k = _random_leaf((B, T, H, D))
    v = _random_leaf((B, T, H, D))
    g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
    g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
    do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda")

    # Each token picks up to S random earlier blocks; unfilled slots keep the
    # out-of-range sentinel T. Slots are then sorted ascending.
    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda")
    for b in range(B):
        for t in range(T):
            for h in range(H):
                picks = torch.randperm(max(1, (t // block_size)))[:S]
                block_indices[b, t, h, :len(picks)] = picks
    block_indices = block_indices.sort(-1).values

    block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda")

    def _pop_grads(*tensors):
        """Return a clone of each tensor's .grad and reset that .grad to None."""
        grads = []
        for tensor in tensors:
            grads.append(tensor.grad.clone())
            tensor.grad = None
        return grads

    ref = naive_nsa(
        q=q,
        k=k,
        v=v,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
    )
    ref.backward(do)
    ref_dq, ref_dk, ref_dv, ref_dg_slc = _pop_grads(q, k, v, g_slc)

    tri = parallel_nsa(
        q=q,
        k=k,
        v=v,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_size=block_size,
        block_counts=block_counts,
    )
    tri.backward(do)
    tri_dq, tri_dk, tri_dv, tri_dg_slc = _pop_grads(q, k, v, g_slc)

    for expected, actual in (
        (ref, tri),
        (ref_dq, tri_dq),
        (ref_dk, tri_dk),
        (ref_dv, tri_dv),
        (ref_dg_slc, tri_dg_slc),
    ):
        torch.testing.assert_close(expected, actual, atol=1e-2, rtol=1e-2)