from typing import Optional, Tuple
import dataclasses

import torch

import flash_mla.cuda as flash_mla_cuda


@dataclasses.dataclass
class FlashMLASchedMeta:
    """
    A class that stores the tile scheduler metadata of FlashMLA

    An empty instance is created by get_mla_metadata(). flash_mla_with_kvcache()
    fills in `config` on its first call with a given instance, and on later calls
    asserts that its arguments still match that stored `config`.
    """

    @dataclasses.dataclass
    class Config:
        """
        Snapshot of the call arguments taken on the first flash_mla_with_kvcache()
        invocation. The field order matters: instances may be built positionally.
        """
        # q.shape[0] (batch_size)
        b: int
        # q.shape[1] (seq_len_q)
        s_q: int
        # q.shape[2] (num_heads_q)
        h_q: int
        # k_cache.shape[1]
        page_block_size: int
        # k_cache.shape[2] (num_heads_k)
        h_k: int

        # The `causal` and `is_fp8_kvcache` flags as passed by the caller.
        causal: bool
        is_fp8_kvcache: bool
        # Last dim of the sparse `indices` tensor, or None when no indices were given (dense attention).
        topk: Optional[int]

        # extra_k_cache.shape[1], or None when extra_k_cache was not given.
        extra_page_block_size: Optional[int]
        # Last dim of extra_indices_in_kvcache, or None when it was not given.
        extra_topk: Optional[int]

    # Set to True by flash_mla_with_kvcache() on its first call with this instance.
    have_initialized: bool = False

    # None until the first flash_mla_with_kvcache() call stores the argument snapshot.
    config: Optional[Config] = None

    # Both tensors below are passed to the decode kernel and then overwritten with the
    # values the kernel returns, on every flash_mla_with_kvcache() call.
    tile_scheduler_metadata: Optional[torch.Tensor] = None   # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    num_splits: Optional[torch.Tensor] = None                # (1), dtype torch.int32.


def get_mla_metadata(
    *args,
    **kwargs
) -> Tuple[FlashMLASchedMeta, None]:
    """
    Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache.

    Arguments:
        This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface.
        Whatever is passed is ignored.

    Return:
        A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful.
    """
    return FlashMLASchedMeta(), None


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: Optional[torch.Tensor],
    cache_seqlens: Optional[torch.Tensor],
    head_dim_v: int,
    tile_scheduler_metadata: FlashMLASchedMeta,
    num_splits: None = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run MLA decoding attention against a paged KV cache.

    Sparse attention is selected when `indices` is given; otherwise dense attention is used.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
                Different modes (including fp8/bf16, and sparsity) have different KV cache layouts. See comments below for details.
                The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. Required for dense attention.
        cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. Required for dense attention.
        head_dim_v: Head_dim of v. Must be 512
        tile_scheduler_metadata: FlashMLASchedMeta (called `sched_meta` below), returned by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
                This object is mutated: the first call records the argument configuration, and every call stores the scheduler tensors returned by the kernel.
        num_splits: must be None (a placeholder kept to be compatible with the old interface).
        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim), computed from q.shape[-1].
        causal: bool. Whether to apply causal attention mask. Only valid for dense attention
        is_fp8_kvcache: bool. Must be True when sparse attention is used.
        indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled.
                    Pay attention that indices[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block),
                    where t is the k-th token of the j-th q-sequence in the i-th batch.
        attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Has no effect on the returned softmax_lse. +inf will cause the result to become 0.
        extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices. Their format requirements are the same as k_cache and indices respectively.
        topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking.

        attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must all be None for dense attention.

    For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2:
        head_dim should be 576 while head_dim_v should be 512.
        In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as:
            - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1.
            - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values.
            - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
            - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: if the arguments violate the constraints above, or are inconsistent with the configuration already recorded in sched_meta.
    """
    # Internal names used throughout the body (and in the assertion messages).
    sched_meta = tile_scheduler_metadata
    indices_in_kvcache = indices
    assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
    assert num_splits is None, "num_splits must be None"

    # Each of these is None when the corresponding optional tensor is absent.
    topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None
    extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None
    extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    if not sched_meta.have_initialized:
        # Sanity check. We only perform sanity check during the first invocation to save CPU time.
        if indices_in_kvcache is not None:
            assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)"

        # Initialize the tile scheduler metadata during the first invocation.
        sched_meta.have_initialized = True
        sched_meta.config = FlashMLASchedMeta.Config(
            b=q.shape[0],
            s_q=q.shape[1],
            h_q=q.shape[2],
            page_block_size=k_cache.shape[1],
            h_k=k_cache.shape[2],

            causal=causal,
            is_fp8_kvcache=is_fp8_kvcache,
            topk=topk,

            extra_page_block_size=extra_k_page_block_size,
            extra_topk=extra_topk,
        )
    else:
        # Check whether the input arguments are consistent with sched_meta
        helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
        assert sched_meta.config is not None
        assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
        assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
        assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
        assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
        assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
        assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
        assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
        assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg
        assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg
        assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg

    if topk is not None:
        # Sparse attention
        assert not causal, "causal must be False when sparse attention is enabled"
        assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled"
        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
            q, k_cache, indices_in_kvcache, topk_length, attn_sink,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
            extra_k_cache, extra_indices_in_kvcache, extra_topk_length,
            head_dim_v, softmax_scale
        )
    else:
        # Dense attention
        assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
        assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(
            q, k_cache, head_dim_v,
            cache_seqlens, block_table,
            softmax_scale, causal,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits
        )
    # Keep whatever scheduler tensors the kernel handed back so the next call can pass them in again.
    sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
    sched_meta.num_splits = new_num_splits
    return (out, lse)


def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel

    Thin wrapper: forwards all arguments unchanged to flash_mla_cuda.sparse_prefill_fwd and returns its result as-is.

    Args:
        q: [s_q, h_q, d_qk], bfloat16
        kv: [s_kv, h_kv, d_qk], bfloat16
        indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
        sm_scale: float
        d_v: The dimension of value vectors. Can only be 512
        attn_sink: optional, [h_q], float32.
            If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)).
            +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros).
            This argument has no effect on lse and max_logits.
        topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices).
            In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation.

    Returns:
        (output, max_logits, lse)
        Please refer to tests/ref.py for the precise definitions of these parameters.
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits:  [s_q, h_q], float
        - lse: [s_q, h_q], float, log-sum-exp of attention scores
    """
    results = flash_mla_cuda.sparse_prefill_fwd(
        q, kv, indices, sm_scale, d_v, attn_sink, topk_length
    )
    return results

def get_mla_decoding_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute tile scheduler metadata as a pair of tensors.

    Thin wrapper: forwards the three arguments unchanged to
    flash_mla_cuda.get_mla_decoding_metadata_dense_fp8 and returns its result as-is.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k)


def flash_mla_with_kvcache_quantization(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention forward pass over a paged, quantized KV cache.

    Thin wrapper around flash_mla_cuda.fwd_kvcache_quantization_mla.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32.
        num_splits: (batch_size + 1), torch.int32.
            NOTE(review): get_mla_metadata in this file returns (FlashMLASchedMeta, None), not tensors; the shapes documented
            for these two arguments match what get_mla_decoding_metadata_dense_fp8 returns. Confirm which producer is intended.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim), computed from q.shape[-1].
        causal: bool. Whether to apply causal attention mask.
        k_scale: {1, torch.float32}, tensor shape is 1. Required despite the None default.
        kv_cache_dtype: only "fp8_e4m3" is supported. Required despite the None default.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: if k_scale or kv_cache_dtype is None.
    """
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
        q,
        k_cache,
        None,  # this slot is always None here; its meaning is defined by the extension
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse


def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention forward pass over a paged KV cache, with the query supplied as two separate tensors.

    Thin wrapper around flash_mla_cuda.fwd_kvcache_mla_nope_pe.

    Arguments:
        q_nope, q_pe: The two parts of the query. Their last dimensions are added together to get head_dim for the
            default softmax_scale. Presumably (batch_size, seq_len_q, num_heads_q, d_nope) and
            (batch_size, seq_len_q, num_heads_q, d_pe) - TODO confirm.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32.
        num_splits: (batch_size + 1), torch.int32.
            NOTE(review): get_mla_metadata in this file returns (FlashMLASchedMeta, None), not tensors - confirm
            which function is meant to produce these two tensors.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # NOTE(review): a function with this same name is defined again later in this file;
    # the later definition is the one that stays bound after import.
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
        q_nope,
        q_pe,
        k_cache,
        None,  # this slot is always None here; its meaning is defined by the extension
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits
    )
    return out, softmax_lse


def flash_mla_with_kvcache_quantization_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention forward pass over a paged, quantized KV cache, with the query supplied as two separate tensors.

    Thin wrapper around flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla.

    Arguments:
        q_nope, q_pe: The two parts of the query. Their last dimensions are added together to get head_dim for the
            default softmax_scale. Presumably (batch_size, seq_len_q, num_heads_q, d_nope) and
            (batch_size, seq_len_q, num_heads_q, d_pe) - TODO confirm.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32.
        num_splits: (batch_size + 1), torch.int32.
            NOTE(review): get_mla_metadata in this file returns (FlashMLASchedMeta, None), not tensors - confirm
            which function is meant to produce these two tensors.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        k_scale: {1, torch.float32}, tensor shape is 1. Required despite the None default.
        kv_cache_dtype: only "fp8_e4m3" is supported. Required despite the None default.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: if k_scale or kv_cache_dtype is None.
    """
    # NOTE(review): a function with this same name is defined twice in this file;
    # the later definition is the one that stays bound after import.
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
        q_nope,
        q_pe,
        k_cache,
        None,  # this slot is always None here; its meaning is defined by the extension
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse

def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention forward pass over a paged KV cache, with the query supplied as two separate tensors.

    Thin wrapper around flash_mla_cuda.fwd_kvcache_mla_nope_pe.

    Arguments:
        q_nope, q_pe: The two parts of the query. Their last dimensions are added together to get head_dim for the
            default softmax_scale. Presumably (batch_size, seq_len_q, num_heads_q, d_nope) and
            (batch_size, seq_len_q, num_heads_q, d_pe) - TODO confirm.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32.
        num_splits: (batch_size + 1), torch.int32.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # NOTE(review): this name is also defined earlier in this file; being later, this definition wins.
    scale = softmax_scale
    if scale is None:
        # Default: 1 / sqrt(head_dim), where head_dim is the sum of both query parts' last dims.
        scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    # The literal None after k_cache is always passed here; its meaning is defined by the extension.
    kernel_args = (
        q_nope, q_pe, k_cache, None,
        head_dim_v, cache_seqlens, block_table,
        scale, causal,
        tile_scheduler_metadata, num_splits,
    )
    attn_out, lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(*kernel_args)
    return attn_out, lse

def flash_mla_with_kvcache_quantization_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention forward pass over a paged, quantized KV cache, with the query supplied as two separate tensors.

    Thin wrapper around flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla.

    Arguments:
        q_nope, q_pe: The two parts of the query. Their last dimensions are added together to get head_dim for the
            default softmax_scale. Presumably (batch_size, seq_len_q, num_heads_q, d_nope) and
            (batch_size, seq_len_q, num_heads_q, d_pe) - TODO confirm.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32.
        num_splits: (batch_size + 1), torch.int32.
            NOTE(review): get_mla_metadata in this file returns (FlashMLASchedMeta, None), not tensors - confirm
            which function is meant to produce these two tensors.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        k_scale: {1, torch.float32}, tensor shape is 1. Required despite the None default.
        kv_cache_dtype: only "fp8_e4m3" is supported. Required despite the None default.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: if k_scale or kv_cache_dtype is None.
    """
    # NOTE(review): a function with this same name is defined twice in this file;
    # the later definition is the one that stays bound after import.
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
        q_nope,
        q_pe,
        k_cache,
        None,  # this slot is always None here; its meaning is defined by the extension
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse


def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention forward pass over a paged KV cache with fp8 descaling factors.

    Thin wrapper around flash_mla_cuda.fwd_kvcache_mla_fp8.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32.
        num_splits: (batch_size + 1), torch.int32.
            NOTE(review): get_mla_metadata in this file returns (FlashMLASchedMeta, None), not tensors - confirm
            which function is meant to produce these two tensors.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim), computed from q.shape[-1].
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
            Both are forwarded to the kernel unchanged, including when left as None.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
        q,
        k_cache,
        None,  # this slot is always None here; its meaning is defined by the extension
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k
    )
    return out, softmax_lse


def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention forward pass over a paged KV cache with fp8 descaling factors, with the query supplied as two separate tensors.

    Thin wrapper around flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32.
        num_splits: (batch_size + 1), torch.int32.
            NOTE(review): get_mla_metadata in this file returns (FlashMLASchedMeta, None), not tensors - confirm
            which function is meant to produce these two tensors.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim),
            where head_dim is q_nope.shape[-1] + q_pe.shape[-1].
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
            Both are forwarded to the kernel unchanged, including when left as None.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
        q_nope,
        q_pe,
        k_cache,
        None,  # this slot is always None here; its meaning is defined by the extension
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k
    )
    return out, softmax_lse



# def flash_mla_with_kvcache_qkvfp8(
#     q: torch.Tensor,
#     k_cache: torch.Tensor,
#     block_table: Optional[torch.Tensor],
#     cache_seqlens: Optional[torch.Tensor],
#     head_dim_v: int,
#     tile_scheduler_metadata: FlashMLASchedMeta,
#     num_splits: None = None,
#     softmax_scale: Optional[float] = None,
#     causal: bool = False,
#     descale_q: Optional[torch.Tensor] = None,
#     descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Arguments:
#         q: (batch_size, seq_len_q, num_heads_q, head_dim).
#         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
#         block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
#         cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
#         head_dim_v: Head_dim of v. Must be 512
#         sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
#         num_splits_placeholder: must be "None" (to be compatible with the old interface).
#         softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
#         causal: bool. Whether to apply causal attention mask. Only valid for dense attention
#         descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
#         descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
#     Return:
#         out: (batch_size, seq_len_q, num_heads_q, head_dim_v). Only bf16 output is supported.
#         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
#     """
#     sched_meta = tile_scheduler_metadata
#     assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
#     assert num_splits is None, "num_splits must be None"

#     if softmax_scale is None:
#         softmax_scale = q.shape[-1] ** (-0.5)

#     if not sched_meta.have_initialized:
#         # Initialize the tile scheduler metadata during the first invocation.
#         sched_meta.have_initialized = True
#         sched_meta.config = FlashMLASchedMeta.Config(
#             q.shape[0],
#             q.shape[1],
#             q.shape[2],
#             k_cache.shape[1],
#             k_cache.shape[2],
#             causal,
#             False,
#             0,
#             0,
#             0
#         )
#     else:
#         # Check whether the input arguments are consistent with sched_meta
#         helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
#         assert sched_meta.config is not None
#         assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
#         assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
#         assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
#         assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
#         assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
#         assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
#         assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
 

#     # Dense attention
#     assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
#     out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
#         q, k_cache, head_dim_v,
#         cache_seqlens, block_table,
#         softmax_scale, causal,
#         sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
#         descale_q, descale_k
#     )
#     sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
#     sched_meta.num_splits = new_num_splits
#     return (out, lse)

# def flash_mla_with_kvcache_kvfp8(
#     q: torch.Tensor,
#     k_cache: torch.Tensor,
#     block_table: Optional[torch.Tensor],
#     cache_seqlens: Optional[torch.Tensor],
#     head_dim_v: int,
#     tile_scheduler_metadata: FlashMLASchedMeta,
#     num_splits: None = None,
#     softmax_scale: Optional[float] = None,
#     causal: bool = False,
#     descale_q: Optional[torch.Tensor] = None,
#     descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Arguments:
#         q: (batch_size, seq_len_q, num_heads_q, head_dim).
#         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
#         block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Must not be None: this function only implements dense attention.
#         cache_seqlens: (batch_size), torch.int32. Must not be None: this function only implements dense attention.
#         head_dim_v: Head_dim of v. Must be 512
#         tile_scheduler_metadata: FlashMLASchedMeta, returned by get_mla_metadata. You may reuse the same instance across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
#         num_splits: must be None (kept to be compatible with the old interface).
#         softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
#         causal: bool. Whether to apply causal attention mask. Only valid for dense attention
#         descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
#         descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
#     Return:
#         out: (batch_size, seq_len_q, num_heads_q, head_dim_v). Only bf16 output is supported.
#         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
#     """
#     sched_meta = tile_scheduler_metadata
#     assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
#     assert num_splits is None, "num_splits must be None"

#     if softmax_scale is None:
#         softmax_scale = q.shape[-1] ** (-0.5)

#     if not sched_meta.have_initialized:
#         # Initialize the tile scheduler metadata during the first invocation.
#         sched_meta.have_initialized = True
#         sched_meta.config = FlashMLASchedMeta.Config(
#             q.shape[0],
#             q.shape[1],
#             q.shape[2],
#             k_cache.shape[1],
#             k_cache.shape[2],
#             causal,
#             False,
#             0,
#             0,
#             0
#         )
#     else:
#         # Check whether the input arguments are consistent with sched_meta
#         helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
#         assert sched_meta.config is not None
#         assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
#         assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
#         assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
#         assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
#         assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
#         assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
#         assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
 

#     # Dense attention
#     assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
#     out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
#         q, k_cache, head_dim_v,
#         cache_seqlens, block_table,
#         softmax_scale, causal,
#         sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
#         descale_q, descale_k
#     )
#     sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
#     sched_meta.num_splits = new_num_splits
#     return (out, lse)