flash_mla_interface.py 31.2 KB
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
from typing import Optional, Tuple
2
import dataclasses
Jiashi Li's avatar
Jiashi Li committed
3
4
5

import torch

6
import flash_mla.cuda as flash_mla_cuda
Jiashi Li's avatar
Jiashi Li committed
7

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@dataclasses.dataclass
class FlashMLASchedMeta:
    """
    Tile-scheduler state shared across FlashMLA decoding calls.

    get_mla_metadata() hands out an empty instance. On its first use,
    flash_mla_with_kvcache() records the call configuration in `config` and
    sets `have_initialized`. After every call it stores the scheduler tensors
    returned by the kernel in `tile_scheduler_metadata` and `num_splits`.
    """

    @dataclasses.dataclass
    class Config:
        # Taken from q: batch size, query sequence length, number of query heads.
        b: int
        s_q: int
        h_q: int
        # Taken from k_cache: page block size (dim 1) and number of KV heads (dim 2).
        page_block_size: int
        h_k: int

        # Attention-mode flags.
        causal: bool
        is_fp8_kvcache: bool
        # Last dim of the sparse indices tensor; None means dense attention.
        topk: Optional[int]

        # Page block size / topk of the optional extra KV cache (None when absent).
        extra_page_block_size: Optional[int]
        extra_topk: Optional[int]

    # Set once the first flash_mla_with_kvcache call has recorded `config`.
    have_initialized: bool = dataclasses.field(default=False)

    # Call configuration the cached scheduler tensors are valid for.
    config: Optional[Config] = dataclasses.field(default=None)

    # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    tile_scheduler_metadata: Optional[torch.Tensor] = dataclasses.field(default=None)
    # dtype torch.int32. The original comment gives shape (1); NOTE(review): other
    # APIs in this module describe num_splits as (batch_size + 1) — confirm.
    num_splits: Optional[torch.Tensor] = dataclasses.field(default=None)


Jiashi Li's avatar
Jiashi Li committed
37
def get_mla_metadata(
    *args,
    **kwargs
) -> Tuple[FlashMLASchedMeta, None]:
    """
    Return an empty FlashMLASchedMeta.

    The actual scheduling metadata is generated lazily during the first call to
    flash_mla_with_kvcache that receives the returned object.

    Arguments:
        *args, **kwargs: Ignored. They are accepted only so that callers written
            against the old interface (which passed scheduling arguments here)
            keep working.

    Return:
        A tuple (FlashMLASchedMeta, None). The tuple form is kept for historical
        compatibility; only the first element is meaningful.
    """
    return FlashMLASchedMeta(), None
Jiashi Li's avatar
Jiashi Li committed
51
52


53
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: Optional[torch.Tensor],
    cache_seqlens: Optional[torch.Tensor],
    head_dim_v: int,
    tile_scheduler_metadata: FlashMLASchedMeta,
    num_splits: None = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run FlashMLA decoding attention over a paged KV cache, in dense or sparse mode.

    Sparse mode is selected by passing `indices`. Otherwise dense attention is used,
    which requires `block_table` and `cache_seqlens`.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
                Different modes (fp8/bf16, dense/sparse) have different KV cache layouts. See the notes below for details.
                The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache till its last byte, is a valid memory address to visit (i.e. won't cause an illegal memory access). In other words, the KV cache may be a slice of a larger array, but cannot be a list of disjoint memory blocks.
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Required for dense attention; can be None when sparse attention is used.
        cache_seqlens: (batch_size), torch.int32. Required for dense attention; can be None when sparse attention is used.
        head_dim_v: Head dim of v. Must be 512.
        tile_scheduler_metadata: FlashMLASchedMeta, as returned by get_mla_metadata (called `sched_meta` below). You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
        num_splits: Placeholder that must be None (kept for compatibility with the old interface).
        softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim), i.e. q.shape[-1] ** -0.5.
        causal: bool. Whether to apply a causal attention mask. Only valid for dense attention.
        is_fp8_kvcache: bool. Whether the KV cache uses the FP8 layout. In sparse mode with is_fp8_kvcache=False, k_cache (and extra_k_cache, if given) must be torch.bfloat16.
        indices: (batch_size, seq_len_q, topk). KV indices; passing this enables sparse attention.
                    Note that indices[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block),
                    where t is the k-th token of the j-th q-sequence in the i-th batch.
        attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. Sparse attention only. If present, the final output is scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Has no effect on the returned softmax_lse. +inf causes the result to become 0.
        extra_k_cache and extra_indices_in_kvcache: Sparse attention only. If provided, these extra tokens are attended to in addition to those in k_cache / indices. Their format requirements are the same as k_cache and indices respectively.
        topk_length/extra_topk_length: (batch_size, ), torch.int32. Sparse attention only. If provided, only the leftmost topk_length indices are processed. Useful when the actual topk differs across queries, saving computation compared to masking.

    For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2:
        head_dim should be 576 while head_dim_v should be 512.
        In FP8+sparse mode, each token's KV cache is 656 bytes, structured as:
            - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1.
            - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values.
            - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
            - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    sched_meta = tile_scheduler_metadata
    indices_in_kvcache = indices
    assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
    assert num_splits is None, "num_splits must be None"

    # Derived configuration values; topk being non-None is what selects sparse mode below.
    topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None
    extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None
    extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    if not sched_meta.have_initialized:
        # Sanity check. We only perform sanity check during the first invocation to save CPU time.
        if indices_in_kvcache is not None:
            assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)"

        # Record the call configuration on first use. The scheduler tensors themselves are
        # filled in from the kernel's return values at the end of this function.
        # Keyword arguments keep this robust to field reordering in Config.
        sched_meta.have_initialized = True
        sched_meta.config = FlashMLASchedMeta.Config(
            b=q.shape[0],
            s_q=q.shape[1],
            h_q=q.shape[2],
            page_block_size=k_cache.shape[1],
            h_k=k_cache.shape[2],
            causal=causal,
            is_fp8_kvcache=is_fp8_kvcache,
            topk=topk,
            extra_page_block_size=extra_k_page_block_size,
            extra_topk=extra_topk,
        )
    else:
        # Check whether the input arguments are consistent with sched_meta
        helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
        assert sched_meta.config is not None
        assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
        assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
        assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
        assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
        assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
        assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
        assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
        assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg
        assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg
        assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg

    if topk is not None:
        # Sparse attention
        assert not causal, "causal must be False when sparse attention is enabled"
        if not is_fp8_kvcache:
            assert k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
            if extra_k_cache is not None:
                assert extra_k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires extra_k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
            q, k_cache, indices_in_kvcache, topk_length, attn_sink,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
            extra_k_cache, extra_indices_in_kvcache, extra_topk_length,
            head_dim_v, softmax_scale
        )
    else:
        # Dense attention
        assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
        assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(
            q, k_cache, head_dim_v,
            cache_seqlens, block_table,
            softmax_scale, causal,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits
        )

    # Cache the scheduler tensors the kernel returned so later calls on the same
    # sched_meta can pass them back in.
    sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
    sched_meta.num_splits = new_num_splits
    return (out, lse)
177
178


179
180
181
182
183
184
def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel.

    Args:
        q: [s_q, h_q, d_qk], bfloat16
        kv: [s_kv, h_kv, d_qk], bfloat16
        indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or to numbers >= s_kv.
        sm_scale: float. Softmax scale applied to QK^T.
        d_v: The dimension of value vectors. Can only be 512.
        attn_sink: optional, [h_q], float32.
            If attn_sink is provided, the output is additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)).
            +-inf in attn_sink is handled normally (i.e., -inf has no effect, +inf makes the corresponding output all zeros).
            This argument has no effect on lse and max_logits.
        topk_length: optional, [s_q], int32. If provided, the i-th q token only attends to the k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if they are provided in indices).
            In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), the operator output will contain NaN, so please avoid this situation.

    Returns:
        A tuple (output, max_logits, lse). Please refer to tests/ref.py for the precise definitions.
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits: [s_q, h_q], float
        - lse: [s_q, h_q], float, log-sum-exp of attention scores
    """
    # Thin wrapper: all arguments are forwarded positionally to the extension.
    return flash_mla_cuda.sparse_prefill_fwd(
        q, kv, indices, sm_scale, d_v, attn_sink, topk_length
    )

zhanghj2's avatar
zhanghj2 committed
216
217
218
def get_mla_decoding_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the tile scheduler metadata for the dense FP8 / quantized decode entry points.

    The returned tensors are what flash_mla_with_kvcache_fp8 and the related
    kvcache_*_fp8 / quantization functions expect as `tile_scheduler_metadata`
    and `num_splits`.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Should equal seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: Number of KV heads.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k)
232

zhanghj2's avatar
zhanghj2 committed
233
def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding over an FP8 paged KV cache.

    Supported configurations (per the original note):
        1) q and KV cache in fp8 e4m3 (gfx938)
        2) q in bf16/fp16 with KV cache in fp8 e5m2 (gfx936, gfx938)
    NOTE(review): the original note also said "descale_q descale_k only support 1",
    which presumably restricts the descale tensors to a single element or the value 1.
    This contradicts the (batch_size) shapes documented below; confirm against the kernel.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim), i.e. q.shape[-1] ** -0.5.
        causal: bool. Whether to apply a causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
        q,
        k_cache,
        # NOTE(review): always None here; presumably an optional separate V cache
        # (V read from k_cache). Confirm against the extension signature.
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k
    )
    return out, softmax_lse
zhanghj2's avatar
zhanghj2 committed
284

zhanghj2's avatar
zhanghj2 committed
285
def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding over an FP8 paged KV cache, with q passed as separate NoPE and RoPE parts.

    Supported configurations (per the original note):
        1) q_nope, q_pe and k_cache in fp8 e4m3 (gfx938)
        2) q_nope, q_pe in bf16 with k_cache in fp8 e4m3 (gfx938)
        3) q_nope, q_pe in bf16 with k_cache in fp8 e5m2 (gfx936, gfx938)
        4) q_nope, q_pe in fp16 with k_cache in fp8 e5m2 (gfx936, gfx938)
    NOTE(review): the original note also said "descale_q descale_k only support 1",
    which presumably restricts the descale tensors to a single element or the value 1.
    This contradicts the (batch_size) shapes documented below; confirm against the kernel.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(q_nope.shape[-1] + q_pe.shape[-1]).
        causal: bool. Whether to apply a causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        # The full query head dim is the NoPE part plus the RoPE part.
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
        q_nope,
        q_pe,
        k_cache,
        # NOTE(review): always None here; presumably an optional separate V cache. Confirm.
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k
    )
    return out, softmax_lse

def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding over a paged KV cache, with q passed as separate NoPE and RoPE parts.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, d_nope). NoPE part of the query.
        q_pe: (batch_size, seq_len_q, num_heads_q, d_rope). RoPE part of the query.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32.
        num_splits: (batch_size + 1), torch.int32.
            NOTE(review): these two were documented as "returned by get_mla_metadata",
            but get_mla_metadata in this module now returns a FlashMLASchedMeta, not
            tensors. Presumably they come from get_mla_decoding_metadata_dense_fp8 or
            an equivalent; confirm.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(d_nope + d_rope).
        causal: bool. Whether to apply a causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        # The full query head dim is the NoPE part plus the RoPE part.
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
        q_nope,
        q_pe,
        k_cache,
        # NOTE(review): always None here; presumably an optional separate V cache. Confirm.
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits
    )
    return out, softmax_lse

zhanghj2's avatar
zhanghj2 committed
387
388
def flash_mla_with_kvcache_quantization(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale: Optional[torch.Tensor] = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding over a quantized paged KV cache.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim), i.e. q.shape[-1] ** -0.5.
        causal: bool. Whether to apply a causal attention mask.
        k_scale: torch.float32 tensor of shape (1,). Required despite the None default.
        kv_cache_dtype: Required despite the None default. Only fp8_e5m2 is supported
            (per the original note). NOTE(review): the exact type/format expected
            (e.g. a string) is defined by the extension; confirm.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: if k_scale or kv_cache_dtype is None.
    """
    # The None defaults exist only for signature compatibility; both values are mandatory.
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
        q,
        k_cache,
        # NOTE(review): always None here; presumably an optional separate V cache. Confirm.
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse
zhanghj2's avatar
zhanghj2 committed
435

zhanghj2's avatar
zhanghj2 committed
436
def flash_mla_with_kvcache_quantization_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale: Optional[torch.Tensor] = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding over a quantized paged KV cache, with q passed as separate NoPE and RoPE parts.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, d_nope). NoPE part of the query.
        q_pe: (batch_size, seq_len_q, num_heads_q, d_rope). RoPE part of the query.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(d_nope + d_rope).
        causal: bool. Whether to apply a causal attention mask.
        k_scale: torch.float32 tensor of shape (1,). Required despite the None default.
        kv_cache_dtype: Required despite the None default. Only fp8_e5m2 is supported
            (per the original note). NOTE(review): the exact type/format expected
            (e.g. a string) is defined by the extension; confirm.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: if k_scale or kv_cache_dtype is None.
    """
    # The None defaults exist only for signature compatibility; both values are mandatory.
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        # The full query head dim is the NoPE part plus the RoPE part.
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
        q_nope,
        q_pe,
        k_cache,
        # NOTE(review): always None here; presumably an optional separate V cache. Confirm.
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse



zhanghj2's avatar
zhanghj2 committed
489
490
491



zhanghj2's avatar
zhanghj2 committed
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
# def flash_mla_with_kvcache_qkvfp8(
#     q: torch.Tensor,
#     k_cache: torch.Tensor,
#     block_table: Optional[torch.Tensor],
#     cache_seqlens: Optional[torch.Tensor],
#     head_dim_v: int,
#     tile_scheduler_metadata: FlashMLASchedMeta,
#     num_splits: None = None,
#     softmax_scale: Optional[float] = None,
#     causal: bool = False,
#     descale_q: Optional[torch.Tensor] = None,
#     descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Arguments:
#         q: (batch_size, seq_len_q, num_heads_q, head_dim).
#         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
#         block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
#         cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
#         head_dim_v: Head_dim of v. Must be 512
#         sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
#         num_splits_placeholder: must be "None" (to be compatible with the old interface).
#         softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
#         causal: bool. Whether to apply causal attention mask. Only valid for dense attention
#         descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
#         descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
#     Return:
#         out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
#         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
#     """
#     sched_meta = tile_scheduler_metadata
#     assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
#     assert num_splits is None, "num_splits must be None"

#     if softmax_scale is None:
#         softmax_scale = q.shape[-1] ** (-0.5)

#     if not sched_meta.have_initialized:
#         # Initialize the tile scheduler metadata during the first invocation.
#         sched_meta.have_initialized = True
#         sched_meta.config = FlashMLASchedMeta.Config(
#             q.shape[0],
#             q.shape[1],
#             q.shape[2],
#             k_cache.shape[1],
#             k_cache.shape[2],
#             causal,
#             False,
#             0,
#             0,
#             0
#         )
#     else:
#         # Check whether the input arguments are consistent with sched_meta
#         helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
#         assert sched_meta.config is not None
#         assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
#         assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
#         assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
#         assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
#         assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
#         assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
#         assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
 

#     # Dense attention
#     assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
#     out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
#         q, k_cache, head_dim_v,
#         cache_seqlens, block_table,
#         softmax_scale, causal,
#         sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
#         descale_q, descale_k
#     )
#     sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
#     sched_meta.num_splits = new_num_splits
#     return (out, lse)

# def flash_mla_with_kvcache_kvfp8(
#     q: torch.Tensor,
#     k_cache: torch.Tensor,
#     block_table: Optional[torch.Tensor],
#     cache_seqlens: Optional[torch.Tensor],
#     head_dim_v: int,
#     tile_scheduler_metadata: FlashMLASchedMeta,
#     num_splits: None = None,
#     softmax_scale: Optional[float] = None,
#     causal: bool = False,
#     descale_q: Optional[torch.Tensor] = None,
#     descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Arguments:
#         q: (batch_size, seq_len_q, num_heads_q, head_dim).
#         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
#         block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Required: this function only implements dense attention, so it must not be None.
#         cache_seqlens: (batch_size), torch.int32. Required: this function only implements dense attention, so it must not be None.
#         head_dim_v: Head dimension of V. Must be 512.
#         tile_scheduler_metadata: FlashMLASchedMeta, as returned by get_mla_metadata. You may reuse the same object across different invocations, but only when the tensor shapes, the value of causal, and the values of cache_seqlens remain the same.
#         num_splits: must be None (kept for compatibility with the old interface).
#         softmax_scale: float. The scaling factor applied to QK^T before softmax. Defaults to 1 / sqrt(head_dim_k).
#         causal: bool. Whether to apply a causal attention mask. Only valid for dense attention.
#         descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
#         descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
#     Return:
#         out: (batch_size, seq_len_q, num_heads_q, head_dim_v). Only bf16 output is supported.
#         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
#     """
#     sched_meta = tile_scheduler_metadata
#     assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
#     assert num_splits is None, "num_splits must be None"

#     if softmax_scale is None:
#         softmax_scale = q.shape[-1] ** (-0.5)

#     if not sched_meta.have_initialized:
#         # Initialize the tile scheduler metadata during the first invocation.
#         sched_meta.have_initialized = True
#         sched_meta.config = FlashMLASchedMeta.Config(
#             q.shape[0],
#             q.shape[1],
#             q.shape[2],
#             k_cache.shape[1],
#             k_cache.shape[2],
#             causal,
#             False,
#             0,
#             0,
#             0
#         )
#     else:
#         # Check whether the input arguments are consistent with sched_meta
#         helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
#         assert sched_meta.config is not None
#         assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
#         assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
#         assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
#         assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
#         assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
#         assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
#         assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
 

#     # Dense attention
#     assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
#     out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
#         q, k_cache, head_dim_v,
#         cache_seqlens, block_table,
#         softmax_scale, causal,
#         sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
#         descale_q, descale_k
#     )
#     sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
#     sched_meta.num_splits = new_num_splits
shenzhe's avatar
shenzhe committed
646
#     return (out, lse)