flash_mla_interface.py 31.5 KB
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
from typing import Optional, Tuple
2
import dataclasses
Jiashi Li's avatar
Jiashi Li committed
3
4
5

import torch

6
import flash_mla.cuda as flash_mla_cuda
Jiashi Li's avatar
Jiashi Li committed
7

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@dataclasses.dataclass
class FlashMLASchedMeta:
    """
    A class that stores the tile scheduler metadata of FlashMLA.

    get_mla_metadata() returns an empty instance. The first flash_mla_with_kvcache()
    call on that instance records the call's problem configuration in `config` and
    stores the scheduler tensors returned by the CUDA decode kernel. Later calls pass
    those tensors back to the kernel and assert that their configuration matches
    the recorded one.
    """

    @dataclasses.dataclass
    class Config:
        """
        Problem configuration a FlashMLASchedMeta instance was built for.

        flash_mla_with_kvcache constructs this positionally, so field order matters.
        """
        b: int                  # batch size, q.shape[0]
        s_q: int                # query sequence length, q.shape[1]
        h_q: int                # number of query heads, q.shape[2]
        page_block_size: int    # k_cache.shape[1]
        h_k: int                # number of KV heads, k_cache.shape[2]

        causal: bool
        is_fp8_kvcache: bool
        topk: Optional[int]     # last dim of the sparse indices; None for dense attention

        extra_page_block_size: Optional[int]  # extra_k_cache.shape[1]; None when no extra KV cache is given
        extra_topk: Optional[int]             # last dim of extra_indices_in_kvcache; None when not given

    have_initialized: bool = False  # set True by the first flash_mla_with_kvcache call on this object

    config: Optional[Config] = None  # configuration recorded by the first call; None until then

    tile_scheduler_metadata: Optional[torch.Tensor] = None   # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    num_splits: Optional[torch.Tensor] = None                # (1), dtype torch.int32. NOTE(review): shape as noted originally; the kernel produces it — confirm.


Jiashi Li's avatar
Jiashi Li committed
37
def get_mla_metadata(
    *args,
    **kwargs
) -> Tuple[FlashMLASchedMeta, None]:
    """
    Create a fresh, not-yet-initialized FlashMLASchedMeta.

    The real scheduling metadata is computed lazily by the first
    flash_mla_with_kvcache call that receives the returned object.

    Arguments:
        None are used. *args and **kwargs are accepted and ignored so that callers
        written against the old interface keep working.

    Return:
        A 2-tuple (FlashMLASchedMeta, None). The trailing None exists only for
        historical compatibility; only the first element is meaningful.
    """
    del args, kwargs  # accepted for backward compatibility only
    empty_meta = FlashMLASchedMeta()
    return empty_meta, None
Jiashi Li's avatar
Jiashi Li committed
51
52


53
def _check_sched_meta_config(
    config: Optional[FlashMLASchedMeta.Config],
    expected: FlashMLASchedMeta.Config,
) -> None:
    """
    Assert that a reused FlashMLASchedMeta was built for the same configuration as the current call.

    Raises:
        AssertionError: if `config` is None or any field differs from `expected`;
            the message names the first mismatching field.
    """
    helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
    assert config is not None
    # (field name, message) pairs, checked in the same order as the per-field asserts they replace.
    field_checks = (
        ("b", "sched_meta.config.b must be equal to batch_size."),
        ("s_q", "sched_meta.config.s_q must be equal to seq_len_q."),
        ("h_q", "sched_meta.config.h_q must be equal to num_heads_q."),
        ("page_block_size", "sched_meta.config.page_block_size must be equal to page_block_size."),
        ("h_k", "sched_meta.config.h_k must be equal to num_heads_k."),
        ("causal", "sched_meta.config.causal must be equal to causal."),
        ("is_fp8_kvcache", "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache."),
        ("topk", "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache."),
        ("extra_page_block_size", "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache."),
        ("extra_topk", "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache."),
    )
    for field_name, msg in field_checks:
        assert getattr(config, field_name) == getattr(expected, field_name), msg + helper_msg


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: Optional[torch.Tensor],
    cache_seqlens: Optional[torch.Tensor],
    head_dim_v: int,
    tile_scheduler_metadata: FlashMLASchedMeta,
    num_splits: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run FlashMLA decoding attention over a paged KV cache.

    Sparse attention is used when `indices` is given; otherwise dense attention is used.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
                Different modes (including fp8/bf16, and sparsity) have different KV cache layouts. See comments below for details.
                The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Required for dense attention; can be None when sparse attention is used.
        cache_seqlens: (batch_size), torch.int32. Required for dense attention; can be None when sparse attention is used.
        head_dim_v: Head_dim of v. Must be 512.
        tile_scheduler_metadata: FlashMLASchedMeta, returned by get_mla_metadata (called `sched_meta` in the error messages). You may reuse the same object across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
                Its configuration is recorded only after the first successful invocation, so a first call that fails validation leaves it reusable.
        num_splits: optional override for BF16 sparse decode. Must be None for FP8 sparse decode and for dense attention; those paths use the value cached in tile_scheduler_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(q.shape[-1]).
        causal: bool. Whether to apply causal attention mask. Only valid for dense attention; must be False for sparse attention.
        is_fp8_kvcache: bool. Whether k_cache (and extra_k_cache) use the FP8 layout described below. When False in sparse mode, they must be torch.bfloat16.
        indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. Must be None for dense attention.
                    Pay attention that indices[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block),
                    where t is the k-th token of the j-th q-sequence in the i-th batch.
        attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. Sparse attention only. If present, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Has no effect on the returned softmax_lse. +inf will cause the result to become 0.
        extra_k_cache and extra_indices_in_kvcache: Sparse attention only. If provided, will attend to these extra tokens in addition to those in k_cache and indices. Their format requirements are the same as k_cache and indices respectively.
        topk_length/extra_topk_length: (batch_size, ), torch.int32. Sparse attention only. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking.

    For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2:
        head_dim should be 576 while head_dim_v should be 512.
        In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as:
            - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1.
            - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values.
            - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
            - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: on invalid argument combinations, or when the arguments are
            inconsistent with a previously initialized tile_scheduler_metadata.
    """
    sched_meta = tile_scheduler_metadata
    indices_in_kvcache = indices
    assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"

    topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None
    extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None
    extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    # Configuration of this call; recorded on first use and compared on reuse.
    call_config = FlashMLASchedMeta.Config(
        q.shape[0],
        q.shape[1],
        q.shape[2],
        k_cache.shape[1],
        k_cache.shape[2],

        causal,
        is_fp8_kvcache,
        topk,

        extra_k_page_block_size,
        extra_topk,
    )

    first_invocation = not sched_meta.have_initialized
    if first_invocation:
        # Sanity check. We only perform this check during the first invocation to save CPU time.
        if indices_in_kvcache is not None:
            assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)"
    else:
        _check_sched_meta_config(sched_meta.config, call_config)

    if topk is not None:
        # Sparse attention
        assert not causal, "causal must be False when sparse attention is enabled"
        if not is_fp8_kvcache:
            assert k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
            if extra_k_cache is not None:
                assert extra_k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires extra_k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
        else:
            assert num_splits is None, "num_splits override is only supported by BF16 sparse decode"
        # A caller-supplied num_splits (BF16 only) takes precedence over the cached one.
        decode_num_splits = num_splits if num_splits is not None else sched_meta.num_splits
        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
            q, k_cache, indices_in_kvcache, topk_length, attn_sink,
            sched_meta.tile_scheduler_metadata, decode_num_splits,
            extra_k_cache, extra_indices_in_kvcache, extra_topk_length,
            head_dim_v, softmax_scale
        )
    else:
        # Dense attention
        assert num_splits is None, "num_splits override is only supported by BF16 sparse decode"
        assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
        assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(
            q, k_cache, head_dim_v,
            cache_seqlens, block_table,
            softmax_scale, causal,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits
        )

    # Commit the configuration only after validation and the kernel call succeeded,
    # so a failed first call does not leave sched_meta bound to rejected arguments.
    if first_invocation:
        sched_meta.config = call_config
        sched_meta.have_initialized = True
    sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
    sched_meta.num_splits = new_num_splits
    return (out, lse)
180
181


182
183
184
185
186
187
def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel.

    Args:
        q: [s_q, h_q, d_qk], bfloat16.
        kv: [s_kv, h_kv, d_qk], bfloat16.
        indices: [s_q, h_kv, topk], int32. Invalid entries should be -1 or any value >= s_kv.
        sm_scale: float softmax scale.
        d_v: Dimension of the value vectors. Only 512 is supported.
        attn_sink: optional, [h_q], float32.
            When given, the output is additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)).
            +-inf is handled normally: -inf has no effect, +inf zeroes the corresponding output.
            It does not affect lse or max_logits.
        topk_length: optional, [s_q], int32. When given, query token i only attends to the
            k tokens in indices[i, :, :topk_length[i]]; later entries are ignored even if present.
            In extremely rare cases (topk_length provided, a valid index lies between
            topk_length[i] and s_kv, and it points to a k token containing NaN) the output
            will contain NaN, so please avoid that situation.

    Returns:
        (output, max_logits, lse); see tests/ref.py for the precise definitions.
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits: [s_q, h_q], float
        - lse: [s_q, h_q], float, log-sum-exp of attention scores
    """
    return flash_mla_cuda.sparse_prefill_fwd(
        q, kv, indices, sm_scale, d_v, attn_sink, topk_length
    )

zhanghj2's avatar
zhanghj2 committed
219
220
221
def get_mla_decoding_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute tile-scheduler tensors for the tensor-based dense decode wrappers
    (flash_mla_with_kvcache_fp8, flash_mla_with_kvcache_fp8_with_cat and the
    quantization variants).

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: number of KV heads.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    scheduler_outputs = flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(
        cache_seqlens, num_heads_per_head_k, num_heads_k
    )
    return scheduler_outputs
235

zhanghj2's avatar
zhanghj2 committed
236
def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding with an FP8 KV cache.

    Supported combinations:
        1) q, k, v all fp8 e4m3 (gfx938)
        2) q bf16/fp16 with kv fp8 e5m2 (gfx936, gfx938)
        descale_q / descale_k: only 1 is supported.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, from get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, from get_mla_decoding_metadata_dense_fp8.
        softmax_scale: Scale applied to QK^T before softmax. Defaults to 1 / sqrt(q.shape[-1]).
        causal: Whether to apply a causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q (fp8 quantization).
        descale_k: (batch_size), torch.float32. Descaling factors for K (fp8 quantization).

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
        q, k_cache,
        None,  # NOTE(review): presumably the optional separate V cache (V is read from k_cache) — confirm
        head_dim_v, cache_seqlens, block_table,
        scale, causal,
        tile_scheduler_metadata, num_splits,
        descale_q, descale_k,
    )
    return out, softmax_lse
zhanghj2's avatar
zhanghj2 committed
287

zhanghj2's avatar
zhanghj2 committed
288
def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding with an FP8 KV cache and the query split into NoPE and RoPE parts.

    Supported combinations:
        1) q_nope, q_pe, k_cache fp8 e4m3 (gfx938)
        2) q_nope, q_pe bf16 with k_cache fp8 e4m3 (gfx938)
        3) q_nope, q_pe bf16 with k_cache fp8 e5m2 (gfx936, gfx938)
        4) q_nope, q_pe fp16 with k_cache fp8 e5m2 (gfx936, gfx938)
        descale_q / descale_k: only 1 is supported.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, from get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, from get_mla_decoding_metadata_dense_fp8.
        softmax_scale: Scale applied to QK^T before softmax.
            Defaults to 1 / sqrt(q_nope.shape[-1] + q_pe.shape[-1]).
        causal: Whether to apply a causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q (fp8 quantization).
        descale_k: (batch_size), torch.float32. Descaling factors for K (fp8 quantization).

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        full_head_dim = q_nope.shape[-1] + q_pe.shape[-1]
        softmax_scale = full_head_dim ** (-0.5)

    kernel_args = (
        q_nope, q_pe, k_cache,
        None,  # NOTE(review): presumably the optional separate V cache (V is read from k_cache) — confirm
        head_dim_v, cache_seqlens, block_table,
        softmax_scale, causal,
        tile_scheduler_metadata, num_splits,
        descale_q, descale_k,
    )
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(*kernel_args)
    return out, softmax_lse

def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding with the query supplied as separate NoPE and RoPE parts.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, nope_dim). The NoPE part of the query.
        q_pe: (batch_size, seq_len_q, num_heads_q, rope_dim). The RoPE part of the query.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32 tensor.
            NOTE: get_mla_metadata in this module returns a FlashMLASchedMeta rather than
            tensors, so it cannot supply this argument. Presumably obtain it from
            get_mla_decoding_metadata_dense_fp8 — TODO confirm.
        num_splits: (batch_size + 1), torch.int32 tensor, produced alongside tile_scheduler_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax.
            Defaults to 1 / sqrt(q_nope.shape[-1] + q_pe.shape[-1]).
        causal: bool. Whether to apply causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        # The full QK head dimension is the concatenation of the NoPE and RoPE parts.
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
        q_nope,
        q_pe,
        k_cache,
        None,  # NOTE(review): presumably the optional separate V cache (V is read from k_cache) — confirm
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits
    )
    return out, softmax_lse

zhanghj2's avatar
zhanghj2 committed
390
391
def flash_mla_with_kvcache_quantization(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale=None,
    kv_cache_dtype=None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding over a quantized KV cache.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(q.shape[-1]).
        causal: bool. Whether to apply causal attention mask.
        k_scale: Required. {1, torch.float32}, tensor shape is 1.
        kv_cache_dtype: Required. Only "fp8_e5m2" is supported.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: if k_scale or kv_cache_dtype is None.
    """
    # The previous message ("... is not None") described the opposite of the failing condition.
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
        q,
        k_cache,
        None,  # NOTE(review): presumably the optional separate V cache (V is read from k_cache) — confirm
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse
zhanghj2's avatar
zhanghj2 committed
438

zhanghj2's avatar
zhanghj2 committed
439
def flash_mla_with_kvcache_quantization_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale=None,
    kv_cache_dtype=None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense MLA decoding over a quantized KV cache, with the query split into NoPE and RoPE parts.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, nope_dim). The NoPE part of the query.
        q_pe: (batch_size, seq_len_q, num_heads_q, rope_dim). The RoPE part of the query.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax.
            Defaults to 1 / sqrt(q_nope.shape[-1] + q_pe.shape[-1]).
        causal: bool. Whether to apply causal attention mask.
        k_scale: Required. {1, torch.float32}, tensor shape is 1.
        kv_cache_dtype: Required. Only "fp8_e5m2" is supported.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: if k_scale or kv_cache_dtype is None.
    """
    # The previous message ("... is not None") described the opposite of the failing condition.
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        # The full QK head dimension is the concatenation of the NoPE and RoPE parts.
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
        q_nope,
        q_pe,
        k_cache,
        None,  # NOTE(review): presumably the optional separate V cache (V is read from k_cache) — confirm
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse



zhanghj2's avatar
zhanghj2 committed
492
493
494



zhanghj2's avatar
zhanghj2 committed
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
# def flash_mla_with_kvcache_qkvfp8(
#     q: torch.Tensor,
#     k_cache: torch.Tensor,
#     block_table: Optional[torch.Tensor],
#     cache_seqlens: Optional[torch.Tensor],
#     head_dim_v: int,
#     tile_scheduler_metadata: FlashMLASchedMeta,
#     num_splits: None = None,
#     softmax_scale: Optional[float] = None,
#     causal: bool = False,
#     descale_q: Optional[torch.Tensor] = None,
#     descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Arguments:
#         q: (batch_size, seq_len_q, num_heads_q, head_dim).
#         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
#         block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
#         cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
#         head_dim_v: Head_dim of v. Must be 512
#         sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
#         num_splits_placeholder: must be "None" (to be compatible with the old interface).
#         softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
#         causal: bool. Whether to apply causal attention mask. Only valid for dense attention
#         descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
#         descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
#     Return:
#         out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
#         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
#     """
#     sched_meta = tile_scheduler_metadata
#     assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
#     assert num_splits is None, "num_splits must be None"

#     if softmax_scale is None:
#         softmax_scale = q.shape[-1] ** (-0.5)

#     if not sched_meta.have_initialized:
#         # Initialize the tile scheduler metadata during the first invocation.
#         sched_meta.have_initialized = True
#         sched_meta.config = FlashMLASchedMeta.Config(
#             q.shape[0],
#             q.shape[1],
#             q.shape[2],
#             k_cache.shape[1],
#             k_cache.shape[2],
#             causal,
#             False,
#             0,
#             0,
#             0
#         )
#     else:
#         # Check whether the input arguments are consistent with sched_meta
#         helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
#         assert sched_meta.config is not None
#         assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
#         assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
#         assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
#         assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
#         assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
#         assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
#         assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
 

#     # Dense attention
#     assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
#     out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
#         q, k_cache, head_dim_v,
#         cache_seqlens, block_table,
#         softmax_scale, causal,
#         sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
#         descale_q, descale_k
#     )
#     sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
#     sched_meta.num_splits = new_num_splits
#     return (out, lse)

# def flash_mla_with_kvcache_kvfp8(
#     q: torch.Tensor,
#     k_cache: torch.Tensor,
#     block_table: Optional[torch.Tensor],
#     cache_seqlens: Optional[torch.Tensor],
#     head_dim_v: int,
#     tile_scheduler_metadata: FlashMLASchedMeta,
#     num_splits: None = None,
#     softmax_scale: Optional[float] = None,
#     causal: bool = False,
#     descale_q: Optional[torch.Tensor] = None,
#     descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Arguments:
#         q: (batch_size, seq_len_q, num_heads_q, head_dim).
#         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
#         block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Required: this variant only implements dense attention.
#         cache_seqlens: (batch_size), torch.int32. Required: this variant only implements dense attention.
#         head_dim_v: Head_dim of v. Must be 512
#         tile_scheduler_metadata (sched_meta): FlashMLASchedMeta, returned by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
#         num_splits: must be None (kept only for compatibility with the old interface).
#         softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim_k).
#         causal: bool. Whether to apply a causal attention mask. Only valid for dense attention.
#         descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
#         descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
#     Return:
#         out: (batch_size, seq_len_q, num_heads_q, head_dim_v). Only bf16 output is supported.
#         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
#     """
#     sched_meta = tile_scheduler_metadata
#     assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
#     assert num_splits is None, "num_splits must be None"

#     if softmax_scale is None:
#         softmax_scale = q.shape[-1] ** (-0.5)

#     if not sched_meta.have_initialized:
#         # Initialize the tile scheduler metadata during the first invocation.
#         sched_meta.have_initialized = True
#         sched_meta.config = FlashMLASchedMeta.Config(
#             q.shape[0],
#             q.shape[1],
#             q.shape[2],
#             k_cache.shape[1],
#             k_cache.shape[2],
#             causal,
#             False,
#             0,
#             0,
#             0
#         )
#     else:
#         # Check whether the input arguments are consistent with sched_meta
#         helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
#         assert sched_meta.config is not None
#         assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
#         assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
#         assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
#         assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
#         assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
#         assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
#         assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
 

#     # Dense attention
#     assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
#     out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
#         q, k_cache, head_dim_v,
#         cache_seqlens, block_table,
#         softmax_scale, causal,
#         sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
#         descale_q, descale_k
#     )
#     sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
#     sched_meta.num_splits = new_num_splits
shenzhe's avatar
shenzhe committed
649
#     return (out, lse)