import math
import dataclasses
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

import flash_mla.cuda as flash_mla_cuda


@triton.jit
def _flash_mla_kernel(
    # input/output pointers
    q_ptr,
    k_cache_ptr,
    v_cache_ptr,
    block_table_ptr,
    cache_seqlens_ptr,
    out_ptr,
    lse_ptr,
    # shape parameters
    batch_size,
    seqlen_q,
    num_heads_q,
    head_size_k,
    head_dim_v,
    num_blocks,
    page_block_size,
    num_heads_k,
    max_num_blocks_per_seq,
    # strides
    stride_q_batch,
    stride_q_seq,
    stride_q_head,
    stride_q_dim,
    stride_out_batch,
    stride_out_seq,
    stride_out_head,
    stride_out_dim,
    stride_lse_batch,
    stride_lse_head,
    stride_lse_seq,
    stride_k_cache_block,
    stride_k_cache_token,
    stride_k_cache_head,
    stride_k_cache_dim,
    stride_v_cache_block,
    stride_v_cache_token,
    stride_v_cache_head,
    stride_v_cache_dim,
    stride_block_table_batch,
    stride_block_table_block,
    # misc parameters
    softmax_scale,
    causal,
    BLOCK_SIZE: tl.constexpr,
    HEAD_DIM_K: tl.constexpr,          # actual head_size_k
    HEAD_DIM_V: tl.constexpr,          # actual head_dim_v
    HEAD_DIM_K_PAD: tl.constexpr,      # next_power_of_2(HEAD_DIM_K)
    HEAD_DIM_V_PAD: tl.constexpr,      # next_power_of_2(HEAD_DIM_V)
):
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_t = tl.program_id(2)

    seq_len_k = tl.load(cache_seqlens_ptr + pid_b)
    if seq_len_k == 0:
        # No KV to attend to: output zeros and set LSE to +inf.
        out_offset = pid_b * stride_out_batch + pid_t * stride_out_seq + pid_h * stride_out_head
        out_offsets = tl.arange(0, HEAD_DIM_V_PAD)  # tl.arange needs a power-of-2 extent
        out_ptrs = out_ptr + out_offset + out_offsets * stride_out_dim
        tl.store(out_ptrs, 0.0, mask=out_offsets < HEAD_DIM_V)
        lse_offset = pid_b * stride_lse_batch + pid_h * stride_lse_head + pid_t * stride_lse_seq
        tl.store(lse_ptr + lse_offset, float('inf'))
        return

    num_heads_q_per_k = num_heads_q // num_heads_k
    key_head_idx = pid_h // num_heads_q_per_k

    q_offset = pid_b * stride_q_batch + pid_t * stride_q_seq + pid_h * stride_q_head
    offs_dk = tl.arange(0, HEAD_DIM_K_PAD)
    q_ptrs = q_ptr + q_offset + offs_dk * stride_q_dim
    mask_dk = offs_dk < HEAD_DIM_K
    q = tl.load(q_ptrs, mask=mask_dk, other=0.0)

    m_i = -float('inf')
    l_i = 0.0
    o_i = tl.zeros([HEAD_DIM_V_PAD], dtype=tl.float32)

    num_k_blocks = (seq_len_k + BLOCK_SIZE - 1) // BLOCK_SIZE
    offset_causal = seq_len_k - seqlen_q   # note: seqlen_q is the kernel argument

    for block_idx in range(num_k_blocks):
        physical_block = tl.load(block_table_ptr + pid_b * stride_block_table_batch + block_idx * stride_block_table_block)

        # size of the current (possibly partial) KV block
        cur_block_size = tl.minimum(seq_len_k - block_idx * BLOCK_SIZE, BLOCK_SIZE)

        # load K
        k_block_ptr = k_cache_ptr + physical_block * stride_k_cache_block
        offs_t = tl.arange(0, BLOCK_SIZE)
        k_ptrs = (k_block_ptr + offs_t[:, None] * stride_k_cache_token +
                  key_head_idx * stride_k_cache_head + offs_dk[None, :] * stride_k_cache_dim)
        mask_k = (offs_t[:, None] < cur_block_size) & mask_dk[None, :]
        K_block = tl.load(k_ptrs, mask=mask_k, other=0.0)

        # load V
        v_block_ptr = v_cache_ptr + physical_block * stride_v_cache_block
        offs_dv = tl.arange(0, HEAD_DIM_V_PAD)  # padded to a power of 2 for tl.arange
        v_ptrs = (v_block_ptr + offs_t[:, None] * stride_v_cache_token +
                  key_head_idx * stride_v_cache_head + offs_dv[None, :] * stride_v_cache_dim)
        mask_v = (offs_t[:, None] < cur_block_size) & (offs_dv[None, :] < HEAD_DIM_V)
        V_block = tl.load(v_ptrs, mask=mask_v, other=0.0)

        # compute attention scores (q is zero-padded beyond HEAD_DIM_K, so the
        # reduction over HEAD_DIM_K_PAD is safe)
        scores = tl.sum(K_block * q[None, :], axis=1) * softmax_scale

        # always mask out padding tokens of a partial block (they were loaded as
        # zeros and would otherwise contribute to the softmax); additionally
        # apply the causal mask when requested
        valid = offs_t < cur_block_size
        if causal:
            kv_pos = block_idx * BLOCK_SIZE + offs_t
            valid = valid & (kv_pos <= pid_t + offset_causal)
        scores = tl.where(valid, scores, -float('inf'))

        m_block = tl.max(scores, axis=0)

        # only update the running stats if this block has at least one valid token
        if m_block != -float('inf'):
            m_new = tl.maximum(m_i, m_block)
            exp_scores = tl.exp(scores - m_new)
            l_i = l_i * tl.exp(m_i - m_new) + tl.sum(exp_scores, axis=0)
            o_i = o_i * tl.exp(m_i - m_new) + tl.sum(exp_scores[:, None] * V_block, axis=0)
            m_i = m_new

    # finalize the softmax normalization
    if l_i == 0.0:
        out_val = tl.zeros([HEAD_DIM_V_PAD], dtype=tl.float32)
        lse_val = float('inf')
    else:
        out_val = o_i / l_i
        lse_val = m_i + tl.log(l_i)

    # write back
    out_offset = pid_b * stride_out_batch + pid_t * stride_out_seq + pid_h * stride_out_head
    out_offsets = tl.arange(0, HEAD_DIM_V_PAD)
    out_ptrs = out_ptr + out_offset + out_offsets * stride_out_dim
    tl.store(out_ptrs, out_val, mask=out_offsets < HEAD_DIM_V)

    lse_offset = pid_b * stride_lse_batch + pid_h * stride_lse_head + pid_t * stride_lse_seq
    tl.store(lse_ptr + lse_offset, lse_val)
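

# A minimal pure-PyTorch reference of the same paged attention (a hedged sketch,
# not part of the library API; the name `_ref_paged_mla_attention` is illustrative).
# It gathers each sequence's KV pages through block_table and runs a plain softmax,
# which is useful for validating the Triton kernel on small shapes.
def _ref_paged_mla_attention(q, k_cache, v_cache, block_table, cache_seqlens,
                             softmax_scale, causal):
    batch, seqlen_q, num_heads_q, head_size_k = q.shape
    _, page_block_size, num_heads_k, _ = k_cache.shape
    head_dim_v = v_cache.shape[-1]
    group = num_heads_q // num_heads_k
    out = torch.zeros(batch, seqlen_q, num_heads_q, head_dim_v, dtype=torch.float32, device=q.device)
    lse = torch.full((batch, num_heads_q, seqlen_q), float('inf'), dtype=torch.float32, device=q.device)
    for b in range(batch):
        s_k = int(cache_seqlens[b])
        if s_k == 0:
            continue
        n_pages = (s_k + page_block_size - 1) // page_block_size
        pages = block_table[b, :n_pages].long()
        k = k_cache[pages].reshape(-1, num_heads_k, head_size_k)[:s_k]  # [s_k, h_k, d_k]
        v = v_cache[pages].reshape(-1, num_heads_k, head_dim_v)[:s_k]   # [s_k, h_k, d_v]
        for h in range(num_heads_q):
            hk = h // group
            scores = (q[b, :, h].float() @ k[:, hk].float().t()) * softmax_scale  # [s_q, s_k]
            if causal:
                kv_pos = torch.arange(s_k, device=q.device)
                q_pos = torch.arange(seqlen_q, device=q.device) + (s_k - seqlen_q)
                scores.masked_fill_(kv_pos[None, :] > q_pos[:, None], float('-inf'))
            lse[b, h] = torch.logsumexp(scores, dim=-1)
            out[b, :, h] = torch.softmax(scores, dim=-1) @ v[:, hk].float()
    return out.to(q.dtype), lse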


kernels = {}


def flash_mla_with_kvcache_triton(
    q: torch.Tensor,                     # [batch, seqlen_q, num_heads_q, head_size_k]
    k_cache: torch.Tensor,               # [num_blocks, page_block_size, num_heads_k, head_size_k]
    v_cache: torch.Tensor,               # [num_blocks, page_block_size, num_heads_k, head_dim_v]
    block_table: torch.Tensor,           # [batch, max_num_blocks_per_seq]
    cache_seqlens: torch.Tensor,         # [batch]
    head_dim_v: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flash MLA attention implemented in Triton, supporting a paged KV cache and
    an optional causal mask.
    Returns (output, lse).
    """
    # shape assertions
    assert q.dim() == 4, "q must be 4D"
    assert k_cache.dim() == 4, "k_cache must be 4D"
    assert v_cache.dim() == 4, "v_cache must be 4D"
    assert block_table.dim() == 2, "block_table must be 2D"
    assert cache_seqlens.dim() == 1, "cache_seqlens must be 1D"
    assert k_cache.shape[2] == v_cache.shape[2], "num_heads_k mismatch"
    assert k_cache.shape[3] == q.shape[3], "head_size_k mismatch"
    assert v_cache.shape[3] == head_dim_v, "head_dim_v mismatch"

    batch, seqlen_q, num_heads_q, head_size_k = q.shape
    num_blocks, page_block_size, num_heads_k, _ = k_cache.shape
    max_num_blocks_per_seq = block_table.shape[1]

    assert page_block_size == 64, "Only page_block_size=64 is supported"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_size_k)

    # padded dimensions (next power of 2), required by tl.arange
    HEAD_DIM_K_PAD = triton.next_power_of_2(head_size_k)
    HEAD_DIM_V_PAD = triton.next_power_of_2(head_dim_v)

    # output tensors
    out = torch.empty((batch, seqlen_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device)
    lse = torch.empty((batch, num_heads_q, seqlen_q), dtype=torch.float32, device=q.device)

    # strides
    stride_q_batch = q.stride(0)
    stride_q_seq = q.stride(1)
    stride_q_head = q.stride(2)
    stride_q_dim = q.stride(3)

    stride_out_batch = out.stride(0)
    stride_out_seq = out.stride(1)
    stride_out_head = out.stride(2)
    stride_out_dim = out.stride(3)

    stride_lse_batch = lse.stride(0)
    stride_lse_head = lse.stride(1)
    stride_lse_seq = lse.stride(2)

    stride_k_cache_block = k_cache.stride(0)
    stride_k_cache_token = k_cache.stride(1)
    stride_k_cache_head = k_cache.stride(2)
    stride_k_cache_dim = k_cache.stride(3)

    stride_v_cache_block = v_cache.stride(0)
    stride_v_cache_token = v_cache.stride(1)
    stride_v_cache_head = v_cache.stride(2)
    stride_v_cache_dim = v_cache.stride(3)

    stride_block_table_batch = block_table.stride(0)
    stride_block_table_block = block_table.stride(1)

    BLOCK_SIZE = page_block_size  # 64
    num_warps = max(1, (head_dim_v + 31) // 32)
    num_stages = 4

    grid = (batch, num_heads_q, seqlen_q)

    # cache compiled kernels by their full specialization, not just BLOCK_SIZE,
    # so different head dims / dtypes don't reuse a mismatched binary
    cache_key = (BLOCK_SIZE, head_size_k, head_dim_v, q.dtype)
    kernel, num_programs = kernels.get(cache_key, (None, 0))

    if kernel is None:
        # Compile the kernel once per specialization (padded dims are passed as constexpr, positional args).
        kernel = _flash_mla_kernel.warmup(
            q, k_cache, v_cache, block_table, cache_seqlens, out, lse,
            batch, seqlen_q, num_heads_q, head_size_k, head_dim_v,
            num_blocks, page_block_size, num_heads_k, max_num_blocks_per_seq,
            stride_q_batch, stride_q_seq, stride_q_head, stride_q_dim,
            stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim,
            stride_lse_batch, stride_lse_head, stride_lse_seq,
            stride_k_cache_block, stride_k_cache_token, stride_k_cache_head, stride_k_cache_dim,
            stride_v_cache_block, stride_v_cache_token, stride_v_cache_head, stride_v_cache_dim,
            stride_block_table_batch, stride_block_table_block,
            softmax_scale, causal,
            BLOCK_SIZE, head_size_k, head_dim_v, HEAD_DIM_K_PAD, HEAD_DIM_V_PAD,
            num_warps=num_warps, num_stages=num_stages, grid=(1,)
        )
        kernels[cache_key] = (kernel, num_programs)


    kernel[grid](
        q, k_cache, v_cache, block_table, cache_seqlens, out, lse,
        batch, seqlen_q, num_heads_q, head_size_k, head_dim_v,
        num_blocks, page_block_size, num_heads_k, max_num_blocks_per_seq,
        stride_q_batch, stride_q_seq, stride_q_head, stride_q_dim,
        stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim,
        stride_lse_batch, stride_lse_head, stride_lse_seq,
        stride_k_cache_block, stride_k_cache_token, stride_k_cache_head, stride_k_cache_dim,
        stride_v_cache_block, stride_v_cache_token, stride_v_cache_head, stride_v_cache_dim,
        stride_block_table_batch, stride_block_table_block,
        softmax_scale, causal,
        BLOCK_SIZE, head_size_k, head_dim_v, HEAD_DIM_K_PAD, HEAD_DIM_V_PAD,
    )
    return out, lse
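

# Illustrative usage of the Triton path (a hedged sketch; requires a CUDA device,
# and the helper name `_example_triton_decode` is not part of the library API).
# Sequence i simply owns pages [i*pages_per_seq, (i+1)*pages_per_seq) here.
def _example_triton_decode():
    torch.manual_seed(0)
    batch, seqlen_q, h_q, h_k = 2, 1, 16, 1
    d_k, d_v, page = 576, 512, 64
    pages_per_seq = 4
    num_blocks = batch * pages_per_seq
    q = torch.randn(batch, seqlen_q, h_q, d_k, dtype=torch.bfloat16, device="cuda")
    k_cache = torch.randn(num_blocks, page, h_k, d_k, dtype=torch.bfloat16, device="cuda")
    v_cache = k_cache[..., :d_v]  # MLA: V is the leading d_v dims of the latent KV
    block_table = torch.arange(num_blocks, dtype=torch.int32, device="cuda").view(batch, pages_per_seq)
    cache_seqlens = torch.tensor([200, 129], dtype=torch.int32, device="cuda")
    out, lse = flash_mla_with_kvcache_triton(
        q, k_cache, v_cache, block_table, cache_seqlens, d_v, causal=True
    )
    return out, lse  # out: [2, 1, 16, 512], lse: [2, 16, 1]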


# The wrappers below and get_mla_metadata keep the original interface unchanged.
def flash_mla_with_kvcache_triton_interface(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata=None,
    num_splits=None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    if indices is None:
        head_dim_v_int = head_dim_v.item() if isinstance(head_dim_v, torch.Tensor) else head_dim_v
        v_cache = k_cache[..., :head_dim_v_int]
        out, lse = flash_mla_with_kvcache_triton(
            q, k_cache, v_cache, block_table, cache_seqlens, head_dim_v, softmax_scale, causal
        )
        return out, lse
    else:
        raise NotImplementedError("Sparse attention is not implemented in Triton version")



@dataclasses.dataclass
class FlashMLASchedMeta:
    """
    A class that stores the tile scheduler metadata of FlashMLA
    """

    @dataclasses.dataclass
    class Config:
        b: int
        s_q: int
        h_q: int
        page_block_size: int
        h_k: int

        causal: bool
        is_fp8_kvcache: bool
        topk: Optional[int]

        extra_page_block_size: Optional[int]
        extra_topk: Optional[int]

    have_initialized: bool = False

    config: Optional[Config] = None

    tile_scheduler_metadata: Optional[torch.Tensor] = None   # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    num_splits: Optional[torch.Tensor] = None                # (1), dtype torch.int32.


def get_mla_metadata(
    *args,
    **kwargs
) -> Tuple[FlashMLASchedMeta, None]:
    """
    Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache.

    Arguments:
        This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface.

    Return:
        A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful.
    """
    return FlashMLASchedMeta(), None
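

# Illustrative decode loop (a hedged sketch; the helper name is not part of the
# library API). The sched_meta returned by get_mla_metadata is reused across
# steps, which is valid as long as the tensor shapes and cache_seqlens values
# stay the same between invocations (see flash_mla_with_kvcache's docstring).
def _example_dense_decode_loop(q, k_cache, block_table, cache_seqlens, num_steps=4):
    sched_meta, _ = get_mla_metadata()
    out = lse = None
    for _ in range(num_steps):
        out, lse = flash_mla_with_kvcache(
            q, k_cache, block_table, cache_seqlens,
            head_dim_v=512, tile_scheduler_metadata=sched_meta,
            causal=True,
        )
    return out, lse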


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: Optional[torch.Tensor],
    cache_seqlens: Optional[torch.Tensor],
    head_dim_v: int,
    tile_scheduler_metadata: FlashMLASchedMeta,
    num_splits: None = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
                Different modes (including fp8/bf16, and sparsity) have different KV cache layouts. See comments below for details.
                The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache to its last byte, is a valid memory address to visit (i.e. won't IMA). In other words, the KV cache can be a slice of a larger array, but cannot be a list of disjoint memory blocks.
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
        cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
        head_dim_v: Head dim of v. Must be 512.
        tile_scheduler_metadata: FlashMLASchedMeta, returned by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
        num_splits: must be None (kept for compatibility with the old interface).
        softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim_k).
        causal: bool. Whether to apply a causal attention mask. Only valid for dense attention.
        is_fp8_kvcache: bool.
        indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled.
                    Note that indices[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t within the page block),
                    where t is the k-th token of the j-th q-sequence in the i-th batch.
        attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If present, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Has no effect on the returned softmax_lse. +inf will cause the result to become 0.
        extra_k_cache and extra_indices_in_kvcache: If provided, the kernel will attend to these extra tokens in addition to those in k_cache and indices. Their format requirements are the same as k_cache and indices respectively.
        topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk differs across queries, so we can save some computation compared to masking.
    
    For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2:
        head_dim should be 576 while head_dim_v should be 512.
        In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as:
            - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1.
            - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values.
            - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
            - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    sched_meta = tile_scheduler_metadata
    indices_in_kvcache = indices
    assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
    assert num_splits is None, "num_splits must be None"

    topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None
    extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None
    extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    if not sched_meta.have_initialized:
        # Sanity check. We only perform sanity check during the first invocation to save CPU time.
        if indices_in_kvcache is not None:
            assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)"
            
        # Initialize the tile scheduler metadata during the first invocation.
        sched_meta.have_initialized = True
        sched_meta.config = FlashMLASchedMeta.Config(
            q.shape[0],
            q.shape[1],
            q.shape[2],
            k_cache.shape[1],
            k_cache.shape[2],

            causal,
            is_fp8_kvcache,
            topk,

            extra_k_page_block_size,
            extra_topk,
        )
    else:
        # Check whether the input arguments are consistent with sched_meta
        helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
        assert sched_meta.config is not None
        assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
        assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
        assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
        assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
        assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
        assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
        assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
        assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg
        assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg
        assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg

    if topk is not None:
        # Sparse attention
        assert not causal, "causal must be False when sparse attention is enabled"
        assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled"
        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
            q, k_cache, indices_in_kvcache, topk_length, attn_sink,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
            extra_k_cache, extra_indices_in_kvcache, extra_topk_length,
            head_dim_v, softmax_scale
        )
    else:
        # Dense attention
        assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
        assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
        # out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(
        #     q, k_cache, head_dim_v,
        #     cache_seqlens, block_table,
        #     softmax_scale, causal,
        #     sched_meta.tile_scheduler_metadata, sched_meta.num_splits
        # )

        out, lse = flash_mla_with_kvcache_triton_interface(
            q, k_cache, block_table, cache_seqlens, head_dim_v,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
            softmax_scale, causal
        )
    # sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
    # sched_meta.num_splits = new_num_splits
    return (out, lse)
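

# Hedged sketch of unpacking one token's 656-byte FP8+sparse cache entry into a
# full bf16 key, following the byte layout documented in the docstring above.
# Assumes `entry` is a contiguous uint8 tensor of shape (656,) and a torch build
# with float8_e4m3fn; the helper name is illustrative, not part of the library.
def _dequant_fp8_sparse_token(entry: torch.Tensor) -> torch.Tensor:
    assert entry.dtype == torch.uint8 and entry.numel() == 656
    nope_q = entry[:512].view(torch.float8_e4m3fn).to(torch.float32)  # 512 quantized NoPE values
    scales = entry[512:528].view(torch.float32)                       # 4 scales, one per 128 values
    nope = (nope_q.view(4, 128) * scales[:, None]).reshape(512)
    rope = entry[528:656].view(torch.bfloat16).to(torch.float32)      # 64 unquantized RoPE values
    return torch.cat([nope, rope]).to(torch.bfloat16)                 # 576-dim key (NoPE ++ RoPE)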


def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel

    Args:
        q: [s_q, h_q, d_qk], bfloat16
        kv: [s_kv, h_kv, d_qk], bfloat16
        indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
        sm_scale: float
        d_v: The dimension of value vectors. Can only be 512
        attn_sink: optional, [h_q], float32.
            If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)).
            +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros).
            This argument has no effect on lse and max_logits.
        topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices).
            In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation.

    Returns:
        (output, max_logits, lse)
        Please refer to tests/ref.py for the precise definitions of these parameters.
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits:  [s_q, h_q], float
        - lse: [s_q, h_q], float, log-sum-exp of attention scores
    """
    results = flash_mla_cuda.sparse_prefill_fwd(
        q, kv, indices, sm_scale, d_v, attn_sink, topk_length
    )
    return results
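

# Illustrative construction of `indices` for the sparse prefill kernel (a hedged
# sketch with toy shapes; the helper name is not part of the library API). Each
# q token attends to its top-k kv positions; slots beyond a row's valid count
# are set to -1, which the kernel treats as invalid. The optional attn_sink then
# rescales the output by exp(lse) / (exp(lse) + exp(attn_sink)) per head.
def _example_sparse_indices(s_q=4, h_kv=1, s_kv=1024, topk=64, device="cuda"):
    scores = torch.randn(s_q, h_kv, s_kv, device=device)     # stand-in selection scores
    indices = scores.topk(topk, dim=-1).indices.int()        # [s_q, h_kv, topk]
    indices[0, :, topk // 2:] = -1                           # e.g. row 0 keeps only topk/2 entries
    return indices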

def get_mla_decoding_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equal to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k)

def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Supported modes:
        1) q/k/v fp8 e4m3 (gfx938)
        2) q bf16/fp16, kv fp8 e5m2 (gfx936, gfx938)
    descale_q/descale_k are only supported in mode 1.
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k
    )
    return out, softmax_lse
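

# Illustrative pairing of the dense-FP8 scheduler metadata with the FP8 decode
# kernel (a hedged sketch; shapes come from the caller and the unit descale
# factors are placeholders, not a recommendation).
def _example_fp8_decode(q, k_cache, block_table, cache_seqlens, head_dim_v=512):
    batch, seqlen_q, num_heads_q, _ = q.shape
    num_heads_k = k_cache.shape[2]
    meta, num_splits = get_mla_decoding_metadata_dense_fp8(
        cache_seqlens, seqlen_q * num_heads_q // num_heads_k, num_heads_k
    )
    descale_q = torch.ones(batch, dtype=torch.float32, device=q.device)
    descale_k = torch.ones(batch, dtype=torch.float32, device=q.device)
    return flash_mla_with_kvcache_fp8(
        q, k_cache, block_table, cache_seqlens, head_dim_v,
        meta, num_splits, causal=True,
        descale_q=descale_q, descale_k=descale_k,
    )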

def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Supported modes:
        1) q_nope/q_pe/k_cache fp8 e4m3 (gfx938)
        2) q_nope/q_pe bf16, k_cache fp8 e4m3 (gfx938)
        3) q_nope/q_pe bf16, k_cache fp8 e5m2 (gfx936, gfx938)
        4) q_nope/q_pe fp16, k_cache fp8 e5m2 (gfx936, gfx938)
    descale_q/descale_k are only supported in mode 1.
    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
        q_nope,
        q_pe,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k
    )
    return out, softmax_lse
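

# Conceptually (a hedged note, up to fp8 quantization effects), the `_with_cat`
# variant matches concatenating the NoPE and RoPE halves of Q on the host first:
#
#     q = torch.cat([q_nope, q_pe], dim=-1)  # [..., 512 + 64]
#     out, lse = flash_mla_with_kvcache_fp8(q, k_cache, block_table, cache_seqlens,
#                                           head_dim_v, tile_scheduler_metadata,
#                                           num_splits)
#
# while fusing the concatenation into the kernel to avoid the extra
# global-memory round trip for the concatenated Q.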

def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
        q_nope,
        q_pe,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits
    )
    return out, softmax_lse

def flash_mla_with_kvcache_quantization(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        k_scale: torch.float32 tensor of shape (1,). Scale factor for the quantized KV cache.
        kv_cache_dtype: str. Only "fp8_e5m2" is supported.
    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must be provided"
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse

def flash_mla_with_kvcache_quantization_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        k_scale: torch.float32 tensor of shape (1,). Scale factor for the quantized KV cache.
        kv_cache_dtype: str. Only "fp8_e5m2" is supported.
    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must be provided"
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
        q_nope,
        q_pe,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse






# def flash_mla_with_kvcache_qkvfp8(
#     q: torch.Tensor,
#     k_cache: torch.Tensor,
#     block_table: Optional[torch.Tensor],
#     cache_seqlens: Optional[torch.Tensor],
#     head_dim_v: int,
#     tile_scheduler_metadata: FlashMLASchedMeta,
#     num_splits: None = None,
#     softmax_scale: Optional[float] = None,
#     causal: bool = False,
#     descale_q: Optional[torch.Tensor] = None,
#     descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Arguments:
#         q: (batch_size, seq_len_q, num_heads_q, head_dim).
#         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
#         block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
#         cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
#         head_dim_v: Head_dim of v. Must be 512
#         sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
#         num_splits_placeholder: must be "None" (to be compatible with the old interface).
#         softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
#         causal: bool. Whether to apply causal attention mask. Only valid for dense attention
#         descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
#         descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
#     Return:
#         out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
#         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
#     """
#     sched_meta = tile_scheduler_metadata
#     assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
#     assert num_splits is None, "num_splits must be None"

#     if softmax_scale is None:
#         softmax_scale = q.shape[-1] ** (-0.5)

#     if not sched_meta.have_initialized:
#         # Initialize the tile scheduler metadata during the first invocation.
#         sched_meta.have_initialized = True
#         sched_meta.config = FlashMLASchedMeta.Config(
#             q.shape[0],
#             q.shape[1],
#             q.shape[2],
#             k_cache.shape[1],
#             k_cache.shape[2],
#             causal,
#             False,
#             0,
#             0,
#             0
#         )
#     else:
#         # Check whether the input arguments are consistent with sched_meta
#         helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
#         assert sched_meta.config is not None
#         assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
#         assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
#         assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
#         assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
#         assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
#         assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
#         assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
 

#     # Dense attention
#     assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
#     out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
#         q, k_cache, head_dim_v,
#         cache_seqlens, block_table,
#         softmax_scale, causal,
#         sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
#         descale_q, descale_k
#     )
#     sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
#     sched_meta.num_splits = new_num_splits
#     return (out, lse)

# def flash_mla_with_kvcache_kvfp8(
#     q: torch.Tensor,
#     k_cache: torch.Tensor,
#     block_table: Optional[torch.Tensor],
#     cache_seqlens: Optional[torch.Tensor],
#     head_dim_v: int,
#     tile_scheduler_metadata: FlashMLASchedMeta,
#     num_splits: None = None,
#     softmax_scale: Optional[float] = None,
#     causal: bool = False,
#     descale_q: Optional[torch.Tensor] = None,
#     descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Arguments:
#         q: (batch_size, seq_len_q, num_heads_q, head_dim).
#         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
#         block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
#         cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
#         head_dim_v: Head_dim of v. Must be 512
#         sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
#         num_splits_placeholder: must be "None" (to be compatible with the old interface).
#         softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
#         causal: bool. Whether to apply causal attention mask. Only valid for dense attention
#         descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
#         descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
#     Return:
#         out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
#         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
#     """
#     sched_meta = tile_scheduler_metadata
#     assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
#     assert num_splits is None, "num_splits must be None"

#     if softmax_scale is None:
#         softmax_scale = q.shape[-1] ** (-0.5)

#     if not sched_meta.have_initialized:
#         # Initialize the tile scheduler metadata during the first invocation.
#         sched_meta.have_initialized = True
#         sched_meta.config = FlashMLASchedMeta.Config(
#             q.shape[0],
#             q.shape[1],
#             q.shape[2],
#             k_cache.shape[1],
#             k_cache.shape[2],
#             causal,
#             False,
#             0,
#             0,
#             0
#         )
#     else:
#         # Check whether the input arguments are consistent with sched_meta
#         helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
#         assert sched_meta.config is not None
#         assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
#         assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
#         assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
#         assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
#         assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
#         assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
#         assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
 

#     # Dense attention
#     assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
#     out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
#         q, k_cache, head_dim_v,
#         cache_seqlens, block_table,
#         softmax_scale, causal,
#         sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
#         descale_q, descale_k
#     )
#     sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
#     sched_meta.num_splits = new_num_splits
#     return (out, lse)