flashmla_sparse.py 41.1 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional

import numpy as np
import torch
8
from vllm import envs
9
10

from vllm import _custom_ops as ops
11
from vllm.config import VllmConfig, get_current_vllm_config
12
from vllm.config.cache import CacheDType
13
from vllm.logger import init_logger
14
15
16
17
from vllm.model_executor.layers.attention.mla_attention import (
    MLACommonBaseImpl,
    get_mla_dims,
)
18
from vllm.platforms import current_platform
19
from vllm.platforms.interface import DeviceCapability
20
from vllm.triton_utils import tl, triton
21
22
from vllm.v1.attention.backend import (
    AttentionBackend,
23
    AttentionCGSupport,
24
25
    AttentionLayer,
    AttentionMetadata,
26
27
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
28
29
    MultipleOf,
)
30
from vllm.v1.attention.backends.utils import (
31
32
33
34
    reshape_attn_output_for_spec_decode,
    reshape_query_for_spec_decode,
    split_decodes_and_prefills,
    split_prefill_chunks,
35
)
36
from vllm.v1.attention.ops.flashmla import (
37
38
    FlashMLASchedMeta,
    flash_mla_sparse_fwd,
39
40
41
    flash_mla_with_kvcache,
    get_mla_metadata,
)
42
from vllm.v1.kv_cache_interface import AttentionSpec
43
from vllm.v1.worker.workspace import current_workspace_manager
44
45
46
47
48

if TYPE_CHECKING:
    from vllm.model_executor.models.deepseek_v2 import Indexer

logger = init_logger(__name__)
49
50
51
52
53
54
55
56
57
58
59
60
61

# For FP8 sparse attention we have two implementations:
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode; this
#    is done by treating all tokens as a single batch.
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
#    (upconverting the FP8 cache to BF16 then calling the prefill kernel) and using
#    the FP8 decode kernel for decode.
# Currently we use #1 when the number of heads per rank is low (i.e. TP) since the
# BF16 prefill kernel requires padding the number of heads to 128 while the decode
# does not, so when the per-rank head count is below MIN_HEADS_FOR_BF16_PREFILL
# (and envs.FP8_USE_MIXED_BATCH is set) we use the mixed batch mode (#1).
MIN_HEADS_FOR_BF16_PREFILL = 32

62
63
64
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format

In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
-   **First 512 bytes:** The "quantized NoPE" part, containing 512
    `float8_e4m3` values.
-   **Next 16 bytes:** Scale factors, containing 4 `float32` values.
    The first `float32` is the scale for the first 128 `float8_e4m3` values,
    the second for the next 128, and so on.
-   **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
    part is not quantized for accuracy.
"""


class FlashMLASparseBackend(AttentionBackend):
    """Backend descriptor for the FlashMLA sparse attention kernels.

    Exposes the static capabilities of the backend (supported dtypes, head
    size, block size, compute capability) together with the builder and
    implementation classes that go with it.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
        "fp8_ds_mla",
    ]

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "FLASHMLA_SPARSE"

    @staticmethod
    def get_impl_cls() -> type["FlashMLASparseImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashMLASparseImpl

    @staticmethod
    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return FlashMLASparseMetadataBuilder

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the KV-cache block sizes the kernels accept (only 64)."""
        return [64]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes (only 576)."""
        return [576]

    @classmethod
    def is_mla(cls) -> bool:
        """This backend is an MLA backend."""
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        """This backend is a sparse-attention backend."""
        return True

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True when the device's major compute capability is 9 or 10."""
        return capability.major in (9, 10)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache tensor shape for the given cache dtype.

        ``num_kv_heads`` does not appear in the returned shape.
        """
        if cache_dtype_str != "fp8_ds_mla":
            return (num_blocks, block_size, head_size)
        # The custom "fp8_ds_mla" storage format uses 656 bytes per token;
        # see the FlashMLA readme.md for details.
        return (num_blocks, block_size, 656)


@dataclass
class FlashMLASparseMetadata(AttentionMetadata):
    """Per-batch attention metadata for the FlashMLA sparse backend.

    Instances are created by ``FlashMLASparseMetadataBuilder.build`` from a
    ``CommonAttentionMetadata``; most fields are copied straight across.
    """

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    num_kv_actual_tokens: int
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    # Request index for every token, sliced to num_actual_tokens by the builder.
    req_id_per_token: torch.Tensor
    block_size: int = 64
    topk_tokens: int = 2048

    @dataclass
    class FP8KernelMetadata:
        """Inputs for one FP8 decode-kernel launch."""

        # First value returned by get_mla_metadata.
        scheduler_metadata: FlashMLASchedMeta
        # Slice of the builder's placeholder block table.
        dummy_block_table: torch.Tensor
        # Slice of the builder's tensor filled with max_model_len.
        cache_lens: torch.Tensor

    @dataclass
    class FP8SeparatePrefillDecode:
        """FP8 metadata when prefill and decode tokens are handled separately."""

        @dataclass
        class Decode:
            kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
            decode_query_len: int  # needed for reshape in spec decode

        @dataclass
        class Prefill:
            # Sequence lengths (context + query) for prefill requests
            # Shape: [num_prefill_reqs]
            seq_lens: torch.Tensor

            # Request ID for each token: -1 for decode tokens, request index
            # (0, 1, 2, ...) for prefill tokens.
            # Shape: [num_actual_tokens]
            request_ids: torch.Tensor

            # Workspace start offsets for all prefill requests
            # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
            # 0-indexed within each chunk. Used to map prefill tokens to workspace
            # offsets in convert_logical_index_to_physical_index
            workspace_starts: torch.Tensor

            @dataclass
            class Chunk:
                """Metadata for a chunk of prefill requests.

                Prefill requests may be chunked to fit within the fixed workspace size.
                """

                # seq_lens of the prefill requests in this chunk.
                seq_lens: torch.Tensor
                # Range of query tokens covered by this chunk.
                tokens_slice: slice
                # Block-table rows of the requests in this chunk.
                block_table: torch.Tensor
                # Index of the chunk's first request among the prefill requests.
                req_start_idx: int
                # View of Prefill.workspace_starts for this chunk's requests.
                workspace_starts: torch.Tensor
                # Total of seq_lens over the requests in this chunk.
                chunk_tot_seqlen: int

            chunks: list[Chunk]

        num_prefills: int = 0
        num_decodes: int = 0
        num_prefill_tokens: int = 0
        num_decode_tokens: int = 0

        decode: Decode | None = None
        prefill: Prefill | None = None

    # None unless the builder uses the "fp8_ds_mla" KV cache; otherwise
    # FP8KernelMetadata in mixed-batch mode, else FP8SeparatePrefillDecode.
    fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
    fp8_use_mixed_batch: bool = False


# Kernel with prefill workspace support
209
210
211
212
213
214
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    prefill_request_id_ptr,  # int32 [num_tokens], -1 for decode, >=0 for prefill
    workspace_starts_ptr,  # int32 [num_prefill_reqs+1] or nullptr
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    HAS_PREFILL: tl.constexpr,
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # Translate per-request token indices into global indices.
    #
    # For each (token, column) entry `tok` of token_indices:
    #   * default:  out = block_table[req, tok // BLOCK_SIZE] * BLOCK_SIZE
    #                     + tok % BLOCK_SIZE
    #   * prefill (HAS_PREFILL and prefill_request_id >= 0):
    #               out = workspace_starts[prefill_request_id] + tok
    #   * out = -1 when tok is negative or tok // BLOCK_SIZE falls outside
    #     [0, max_num_blocks_per_req). This check is applied to prefill
    #     tokens as well, because the final tl.where runs after the override.
    #
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Negative indices (the -1 padding value) propagate as -1
    is_invalid_tok = tok < 0

    # is_prefill stays False when the kernel is compiled without prefill
    # support, so the prefill pointers are never dereferenced in that case.
    is_prefill = False
    if HAS_PREFILL:
        prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
        is_prefill = prefill_req_id >= 0

    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access
    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    is_invalid_tok |= ~valid_block
    # The block-table load is masked off for out-of-range blocks and for
    # prefill tokens (whose result is overridden below).
    base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
    out_val = base * BLOCK_SIZE + inblock_off

    # Override with prefill output if prefill is enabled
    if HAS_PREFILL:
        workspace_start = tl.load(
            workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
        )
        prefill_out = workspace_start + tok
        out_val = tl.where(is_prefill, prefill_out, out_val)
    # Invalid entries win over both the block-table and the prefill result
    out_val = tl.where(is_invalid_tok, -1, out_val)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)


def triton_convert_req_index_to_global_index(
    req_id: torch.Tensor,  # int32 [num_tokens]
    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,  # tile width along columns
    HAS_PREFILL_WORKSPACE: bool = False,
    prefill_workspace_request_ids: torch.Tensor | None = None,
    prefill_workspace_starts: torch.Tensor | None = None,
):
    """
    Convert per-request token indices into global indices.

    out[token_id, indice_id] =
        block_table[req_id[token_id],
            token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Negative entries of token_indices (the -1 padding value) yield -1.
    For safety, -1 is also produced when the derived block_id would be
    out-of-bounds.

    When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace
    offsets instead of global cache slots. prefill_workspace_request_ids and
    prefill_workspace_starts must be provided.

    prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
        prefill request index (maps to prefill_workspace_starts)
    prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
        starts for each prefill request

    If envs.USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX is set, the call is
    forwarded to the lightop implementation and none of the checks below run.
    """
    if envs.USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX:
        from lightop import op

        return op.convert_req_index_to_global_index(
            req_id,
            block_table,
            token_indices,
            BLOCK_SIZE,
            NUM_TOPK_TOKENS,
            HAS_PREFILL_WORKSPACE,
            prefill_workspace_request_ids,
            prefill_workspace_starts,
        )

    # Input validation for the Triton path
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
    )
    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None
        assert prefill_workspace_starts is not None
        assert prefill_workspace_request_ids.dtype == torch.int32
        assert prefill_workspace_starts.dtype == torch.int32
        # These two are handed to the kernel as-is, so they must already be
        # contiguous.
        assert prefill_workspace_request_ids.is_contiguous()
        assert prefill_workspace_starts.is_contiguous()

    # The kernel is launched on contiguous copies of the main inputs
    req_ids = req_id.contiguous()
    blocks = block_table.contiguous()
    indices = token_indices.contiguous()
    result = torch.empty_like(indices)

    # Exact 2D grid: one program per (token, column tile)
    launch_grid = (req_id.shape[0], NUM_TOPK_TOKENS // BLOCK_N)

    _convert_req_index_to_global_index_kernel[launch_grid](
        req_ids,
        blocks,
        indices,
        result,
        prefill_workspace_request_ids,
        prefill_workspace_starts,
        # shapes / constexprs
        block_table.shape[1],
        BLOCK_SIZE,
        BLOCK_N,
        HAS_PREFILL_WORKSPACE,
        # strides (in elements): block table, token indices, output
        *blocks.stride(),
        *indices.stride(),
        *result.stride(),
    )
    return result


380
381
382
383
384
385
386
387
388
389
def get_prefill_workspace_size(max_model_len: int):
    """Return the prefill workspace size as a multiple of ``max_model_len``."""
    # NOTE(Lucas): the multiplier 5 is a magic number controlling the prefill
    # buffer size. May be tuned later.
    # Memory usage: 5 * max_model_len * 576 * 2 bytes
    #   Example: DeepSeek-V3.2 with max_model_len=163840 ->
    #            5 * 163840 * 576 * 2 = ~900 MB
    # This fits nicely below the typical MoE workspace size of >2GB so this is
    # "free"
    workspace_multiplier = 5
    return max_model_len * workspace_multiplier


390
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
    """Builds ``FlashMLASparseMetadata`` from ``CommonAttentionMetadata``.

    The constructor pre-allocates the device and pinned-host buffers that are
    reused by every ``build`` call.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        """Store the configuration and allocate the persistent buffers.

        Args:
            kv_cache_spec: KV-cache spec; its ``block_size`` is forwarded into
                the metadata produced by ``build``.
            layer_names: Names of the layers served by this builder.
            vllm_config: Global vLLM configuration.
            device: Device on which the persistent buffers are allocated.
        """
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device

        # Treat requests with query length <= 1 as decodes to match the
        # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        props = torch.cuda.get_device_properties(device)
        sm_count = props.multi_processor_count

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        # FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
        self.fp8_decode_padded_heads = (
            FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
        )

        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
        self.topk_tokens_tensor = torch.full(
            (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
        )
        # Shape: [max_num_seqs], all elements = max_model_len
        self.max_model_len_tensor = torch.full(
            (max_num_seqs,),
            self.model_config.max_model_len,
            device=device,
            dtype=torch.int32,
        )
        # this is ignored by `flash_mla_with_kvcache` if indices not None
        self.dummy_block_table = torch.empty(
            (max_num_seqs, 1), dtype=torch.int32, device=self.device
        )

        # Equation taken from FlashMLA/csrc/api/sparse_decode.h
        # For sparse FP8 decode, the formula depends on architecture:
        # - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
        # - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
        # - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
        # For max buffer size, use s_q = 1 (the case that produces largest output)
        # Use padded head count since that's what will be passed to the kernel
        h_q = self.fp8_decode_padded_heads
        if current_platform.is_device_capability_family(100):
            # SM100 head64 or head64x2 uses full SM count
            max_num_sm_parts = sm_count
        else:
            # SM90 uses h_q/64 divisor
            max_num_sm_parts = sm_count // max(1, h_q // 64)
        self.tile_scheduler_metadata_buffer = torch.empty(
            # TileSchedulerMetaDataSize = 8
            # see: FlashMLA/csrc/params.h
            (max_num_sm_parts, 8),
            dtype=torch.int32,
            device=device,
        )
        # Sized for per-request batching (num_decodes + 1)
        self.num_splits_buffer = torch.empty(
            (max_num_seqs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )
        # Pinned host staging buffer for req_id_per_token, plus a NumPy view
        # of the same memory so `build` can fill it with NumPy and then issue
        # a non-blocking host-to-device copy.
        self.req_id_per_token_buffer_cpu = torch.zeros(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device="cpu",
            pin_memory=True,
        )
        self.req_id_per_token_buffer_np = self.req_id_per_token_buffer_cpu.numpy()

    def _build_fp8_mixed_decode_prefill(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8KernelMetadata":
        """Build FP8 metadata treating all tokens as one mixed batch.

        This matches main branch's approach and avoids the BF16 prefill kernel
        which has head padding overhead when num_heads is small (high TP case).
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        # Use padded head count since that's what the kernel will see
        padded_heads = self.fp8_decode_padded_heads

        # Build metadata for all tokens as a single batch
        scheduler_metadata, _ = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
            num_q_tokens_per_head_k=num_tokens * padded_heads,
            topk=self.topk_tokens,
            num_heads_q=padded_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=scheduler_metadata,
            cache_lens=self.max_model_len_tensor[:1],
            dummy_block_table=self.dummy_block_table[:1],
        )

        return fp8_metadata

    def _build_fp8_separate_prefill_decode(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
        """Build FP8 metadata with separate prefill and decode parts.

        The batch is split with ``split_decodes_and_prefills``; decode requests
        come first. ``prefill`` is populated only when there are prefill
        requests and ``decode`` only when there are decode requests.
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold or 1,
                require_uniform=True,
            )
        )

        FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
        fp8_metadata = FP8Meta(
            num_decodes=num_decodes,
            num_prefills=num_prefills,
            num_decode_tokens=num_decode_tokens,
            num_prefill_tokens=num_prefill_tokens,
        )

        # For pure decode batches no prefill metadata is built.
        # For mixed batches, prefill_request_id has -1 for decode tokens and the
        # prefill request index for prefill tokens.
        if num_prefills > 0:
            seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
            seq_lens = common_attn_metadata.seq_lens
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

            # Extract prefill sequence lengths (context + query, not just query)
            # Decode requests come first in the batch, prefill requests follow
            prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
            prefill_seq_lens = seq_lens[num_decodes:]

            # Build prefill_request_id: -1 for decode, request index for
            # prefill. This enables a single
            # convert_logical_index_to_physical_index call for all tokens
            prefill_request_id = torch.full(
                (num_tokens,), -1, dtype=torch.int32, device=self.device
            )
            # Map prefill tokens to their request IDs (0, 1, 2, ...)
            for req_idx in range(num_prefills):
                # Get query token range for this prefill request
                global_req_idx = num_decodes + req_idx
                req_query_start = query_start_loc_cpu[global_req_idx]
                req_query_end = query_start_loc_cpu[global_req_idx + 1]
                prefill_request_id[req_query_start:req_query_end] = req_idx

            # will be adjusted by chunk loop
            prefill_workspace_starts_cpu = torch.zeros(
                num_prefills, dtype=torch.int32, pin_memory=True
            )
            prefill_workspace_starts_cpu[1:] = torch.cumsum(
                prefill_seq_lens_cpu[:-1], dim=0
            )
            # populated by non-blocking copy after prefill_workspace_starts_cpu is
            # updated by each chunk
            prefill_workspace_starts = torch.empty(
                num_prefills, dtype=torch.int32, device=self.device
            )

            # Chunk prefill requests to fit within workspace size
            max_prefill_buffer_size = get_prefill_workspace_size(
                self.vllm_config.model_config.max_model_len
            )
            chunk_bounds = split_prefill_chunks(
                prefill_seq_lens_cpu, max_prefill_buffer_size
            )

            prefill_chunks = []
            for chunk_start, chunk_end in chunk_bounds:
                # Adjust workspace_starts in-place per chunk to be
                # 0-indexed within each chunk
                # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
                #   Initial: workspace_starts=[0,10,25,45]
                #   After:   workspace_starts=[0,10,0,20]
                #           (chunk 0 starts at 0, chunk 1 starts at 0)
                offset = prefill_workspace_starts_cpu[chunk_start].item()
                prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset

                chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
                # Chunk.chunk_tot_seqlen is declared as `int`, so convert the
                # 0-d CPU tensor returned by sum() into a plain Python int.
                chunk_tot_seqlen = int(
                    prefill_seq_lens_cpu[chunk_start:chunk_end].sum().item()
                )
                token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
                token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
                tokens_slice = slice(token_start, token_end)

                # Create chunk view of gpu tensor
                chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
                chunk_block_table = common_attn_metadata.block_table_tensor[
                    num_decodes + chunk_start : num_decodes + chunk_end
                ]

                prefill_chunks.append(
                    FP8Meta.Prefill.Chunk(
                        seq_lens=chunk_seq_lens,
                        tokens_slice=tokens_slice,
                        block_table=chunk_block_table,
                        req_start_idx=chunk_start,
                        workspace_starts=chunk_workspace_starts,
                        chunk_tot_seqlen=chunk_tot_seqlen,
                    )
                )

            # The per-chunk views created above alias this tensor, so a single
            # copy after the loop populates all of them.
            prefill_workspace_starts.copy_(
                prefill_workspace_starts_cpu, non_blocking=True
            )

            fp8_metadata.prefill = FP8Meta.Prefill(
                seq_lens=prefill_seq_lens,
                request_ids=prefill_request_id,
                workspace_starts=prefill_workspace_starts,
                chunks=prefill_chunks,
            )

        if num_decodes > 0:
            # Compute decode_query_len for spec decode (uniform due to require_uniform)
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()

            # Use padded head count since that's what the kernel will see
            padded_heads = self.fp8_decode_padded_heads
            scheduler_metadata, _ = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor[:num_decodes],
                num_q_tokens_per_head_k=decode_query_len * padded_heads,
                topk=self.topk_tokens,
                num_heads_q=padded_heads,
                num_heads_k=1,
                is_fp8_kvcache=True,
            )

            kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=scheduler_metadata,
                dummy_block_table=self.dummy_block_table[:num_decodes],
                cache_lens=self.max_model_len_tensor[:num_decodes],
            )
            fp8_metadata.decode = FP8Meta.Decode(
                kernel_metadata=kernel_meta,
                decode_query_len=decode_query_len,
            )

        return fp8_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashMLASparseMetadata:
        """Build the attention metadata for one batch.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Backend-agnostic metadata for the batch.
            fast_build: Not used by this builder.

        Returns:
            The ``FlashMLASparseMetadata`` for the batch. Its
            ``fp8_extra_metadata`` is set only when the KV cache dtype is
            "fp8_ds_mla".
        """
        cm = common_attn_metadata
        num_tokens = cm.num_actual_tokens

        # Expand query_start_loc into one request index per token.
        starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        # Stage through the pinned host buffer, then copy asynchronously to
        # the persistent device buffer.
        self.req_id_per_token_buffer_np[: req_id_per_token.shape[0]] = req_id_per_token
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            self.req_id_per_token_buffer_cpu[: req_id_per_token.shape[0]],
            non_blocking=True,
        )
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata: (
            FlashMLASparseMetadata.FP8SeparatePrefillDecode
            | FlashMLASparseMetadata.FP8KernelMetadata
            | None
        ) = None
        fp8_use_mixed_batch = (
            self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and envs.FP8_USE_MIXED_BATCH
        )
        if self.use_fp8_kv_cache:
            if fp8_use_mixed_batch:
                fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
            else:
                fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)

        metadata = FlashMLASparseMetadata(
            num_reqs=cm.num_reqs,
            max_query_len=cm.max_query_len,
            max_seq_len=cm.max_seq_len,
            num_actual_tokens=cm.num_actual_tokens,
            num_kv_actual_tokens=cm.num_kv_actual_tokens,
            query_start_loc=cm.query_start_loc,
            slot_mapping=cm.slot_mapping,
            block_table=cm.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
            fp8_use_mixed_batch=fp8_use_mixed_batch,
        )

        return metadata


class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
Lucas Wilkinson's avatar
Lucas Wilkinson committed
712
713
714
715
716
717
    @staticmethod
    def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
        # FP8 decode kernel only supports h_q = 64 or 128
        # Compute padded head count for decode
        return 64 if num_heads <= 64 else 128

718
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: Optional["Indexer"] = None,
        **mla_args,
    ) -> None:
        """Set up the sparse FlashMLA attention implementation.

        The ten leading arguments and ``**mla_args`` are forwarded unchanged
        to ``MLACommonBaseImpl.__init__``.

        Args:
            scale: Softmax scale; also stored as ``self.softmax_scale``.
            kv_cache_dtype: KV cache dtype string. When it is
                ``"fp8_ds_mla"`` a BF16 prefill workspace is reserved here.
            topk_indice_buffer: Accepted for interface compatibility; this
                method does not read it (the buffer is taken from
                ``indexer`` instead).
            indexer: Indexer that owns ``topk_indices_buffer``. Required
                despite the ``None`` default.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )
        self.softmax_scale = scale
        assert indexer is not None, "FlashMLASparseImpl requires an indexer"
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
        # Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
        self.prefill_padding = (
            128 if current_platform.is_device_capability_family(100) else 64
        )
        self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

        if kv_cache_dtype == "fp8_ds_mla":
            # Reserve workspace during initialization. Note that
            # ``prefill_workspace_shape`` / ``prefill_bf16_workspace`` only
            # exist on instances created with the FP8 cache dtype.
            vllm_config = get_current_vllm_config()
            assert vllm_config is not None and vllm_config.model_config is not None
            prefill_workspace_size = get_prefill_workspace_size(
                vllm_config.model_config.max_model_len
            )
            self.prefill_workspace_shape = (prefill_workspace_size, head_size)
            (self.prefill_bf16_workspace,) = (
                current_workspace_manager().get_simultaneous(
                    (self.prefill_workspace_shape, torch.bfloat16)
                )
            )

771
    def _forward_bf16_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Run sparse attention for the non-FP8 KV cache path.

        Args:
            q: Query tensor with tokens along dim 0.
            kv_c_and_k_pe_cache: The KV cache tensor.
            topk_indices: Per-request top-k indices; dim 1 is the top-k axis.
            attn_metadata: Metadata supplying ``req_id_per_token``,
                ``block_table`` and ``block_size``.

        Returns:
            The output of ``_bf16_flash_mla_kernel`` for all tokens.
        """
        # Convert per-request indices to global cache slots. No
        # prefill-workspace arguments are passed on this path.
        global_topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
        )

        return self._bf16_flash_mla_kernel(
            q, kv_c_and_k_pe_cache, global_topk_indices
        )

    def _forward_fp8_kv_separate_prefill_decode(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Run FP8 sparse attention with decode and prefill handled separately.

        The first ``num_decode_tokens`` rows of ``q`` / ``topk_indices`` are
        sent to the FP8 kernel. Prefill tokens are processed chunk by chunk:
        each chunk's KV entries are gathered and upconverted into the shared
        BF16 workspace and then passed to the BF16 sparse kernel.

        Args:
            q: Query tensor, ``(num_tokens, num_heads, head_dim)``.
            kv_c_and_k_pe_cache: The FP8 KV cache tensor.
            topk_indices: Per-request top-k indices, ``(num_tokens, topk)``.
            attn_metadata: Metadata whose ``fp8_extra_metadata`` must be an
                ``FP8SeparatePrefillDecode`` instance.

        Returns:
            Attention output with one row per token. For mixed or pure
            prefill batches it has shape
            ``(num_actual_tokens, num_heads, kv_lora_rank)``.
        """
        fp8_metadata = attn_metadata.fp8_extra_metadata
        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
        num_decodes = fp8_metadata.num_decodes
        num_decode_tokens = fp8_metadata.num_decode_tokens
        num_prefill_tokens = fp8_metadata.num_prefill_tokens

        prefill_metadata = fp8_metadata.prefill
        has_prefill_workspace = prefill_metadata is not None
        prefill_request_ids = None
        prefill_workspace_starts = None
        if prefill_metadata is not None:
            prefill_request_ids = prefill_metadata.request_ids
            prefill_workspace_starts = prefill_metadata.workspace_starts

        # Convert per-request indices to global slots (decode) or workspace
        # offsets (prefill). Prefill uses the workspace mapping because its
        # KV entries are upconverted to BF16 in the workspace.
        # prefill_workspace_starts has been adjusted in-place per chunk so
        # prefill indices automatically come out chunk-local
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
            HAS_PREFILL_WORKSPACE=has_prefill_workspace,
            prefill_workspace_request_ids=prefill_request_ids,
            prefill_workspace_starts=prefill_workspace_starts,
        )

        def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
            assert fp8_metadata.decode is not None
            # Reshape q: (num_decode_tokens, num_heads, head_dim)
            #         -> (num_decodes, seq_len, num_heads, head_dim)
            q = reshape_query_for_spec_decode(q, num_decodes)
            seq_len = q.shape[1]
            # Reshape topk_indices: (num_decode_tokens, topk)
            #                    -> (num_decodes, seq_len, topk)
            topk_indices = topk_indices.view(num_decodes, seq_len, -1)
            attn_out, _ = self._fp8_flash_mla_kernel(
                q=q,
                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                topk_indices=topk_indices,
                kernel_metadata=fp8_metadata.decode.kernel_metadata,
            )
            # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
            #              -> (num_decode_tokens, num_heads, head_dim_v)
            return reshape_attn_output_for_spec_decode(attn_out)

        # Pure decode: direct call without allocation
        if num_decode_tokens > 0 and num_prefill_tokens == 0:
            return _fp8_decode(q, topk_indices)

        # Mixed or pure prefill: allocate output tensor. new_empty inherits
        # dtype and device from q.
        attn_out = q.new_empty(
            (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank)
        )

        if num_decode_tokens > 0:
            attn_out[:num_decode_tokens] = _fp8_decode(
                q[:num_decode_tokens], topk_indices[:num_decode_tokens]
            )

        assert prefill_metadata is not None
        for chunk in prefill_metadata.chunks:
            # The same workspace buffer is reused for every chunk, so each
            # chunk must be gathered and consumed before the next one.
            chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
            ops.cp_gather_and_upconvert_fp8_kv_cache(
                kv_c_and_k_pe_cache,
                chunk_workspace,
                chunk.block_table,
                chunk.seq_lens,
                chunk.workspace_starts,
                len(chunk.block_table),
            )

            attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
                q[chunk.tokens_slice],
                chunk_workspace,
                topk_indices[chunk.tokens_slice],
            )

        return attn_out

    def _forward_fp8_kv_mixed_batch(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Run FP8 sparse attention with all tokens folded into one batch.

        Prefill and decode tokens are not separated: the query and the
        converted top-k indices each get a leading batch dimension of 1 and
        go through the FP8 kernel in a single call. This avoids the BF16
        prefill kernel, which has head padding overhead when num_heads is
        small. Used when use_mixed_batch is True.

        Returns:
            Attention output of shape ``(T, H, D_v)``.
        """
        # Convert per-request indices to global cache slots.
        global_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
        )

        kernel_metadata = attn_metadata.fp8_extra_metadata
        # isinstance() is False for None, so this also rejects missing metadata.
        assert isinstance(kernel_metadata, FlashMLASparseMetadata.FP8KernelMetadata)

        # Add the batch dim: (T, H, D) -> (1, T, H, D), (T, topk) -> (1, T, topk)
        batched_q = q.unsqueeze(0)
        batched_indices = global_indices.unsqueeze(0)
        batched_out, _ = self._fp8_flash_mla_kernel(
            q=batched_q,
            kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
            topk_indices=batched_indices,
            kernel_metadata=kernel_metadata,
        )

        # Drop the batch dim again: (1, T, H, D_v) -> (T, H, D_v)
        return batched_out.squeeze(0)

    def _fp8_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call the FP8 sparse FlashMLA kernel on the KV cache.

        Args:
            q: Query tensor, ``(batch, seq_len, num_heads, head_dim)``.
            kv_c_and_k_pe_cache: FP8 KV cache; it is reinterpreted as
                ``uint8`` and given an extra second-to-last dimension before
                being handed to the kernel.
            topk_indices: Top-k indices forwarded as the kernel's
                ``indices`` argument.
            kernel_metadata: Supplies ``dummy_block_table``, ``cache_lens``
                and ``scheduler_metadata`` for the kernel call.

        Returns:
            The ``(out, lse)`` pair returned by ``flash_mla_with_kvcache``.
        """
        # NOTE(review): q used to be zero-padded along the head dimension up
        # to self.fp8_decode_padded_heads (the kernel was documented as only
        # supporting h_q = 64 or 128) and the output sliced back afterwards.
        # That padding is currently disabled, so q is passed through with its
        # actual head count - confirm the kernel build in use accepts it.
        out, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
            block_table=kernel_metadata.dummy_block_table,
            head_dim_v=512,
            cache_seqlens=kernel_metadata.cache_lens,
            tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
            is_fp8_kvcache=True,
            indices=topk_indices,
            softmax_scale=self.softmax_scale,
        )

        return out, lse

969
970
971
972
973
    def _bf16_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Call the BF16 sparse FlashMLA forward kernel.

        Args:
            q: Query tensor with tokens along dim 0.
            kv_c_and_k_pe_cache: KV source (cache or prefill workspace); it is
                viewed as ``(-1, 1, last_dim)`` before the kernel call.
            topk_indices: Top-k indices; viewed as ``(num_tokens, 1, -1)``.

        Returns:
            The first element of the kernel result, sliced to the first
            ``self.num_heads`` entries along dim 1.
        """
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        # NOTE(Chen): kernel requires num_local_head to be a multiple of
        # 64 on hopper and 128 on blackwell
        # NOTE(review): q used to be padded along the head dimension up to
        # self.prefill_padding here. That padding is currently disabled, so q
        # is passed with its actual head count - confirm the kernel build in
        # use accepts it. The slice below is kept and only trims when the
        # kernel returns more than self.num_heads heads.
        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = flash_mla_sparse_fwd(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
        )[0]
        output = output[:, : self.num_heads, :]
        return output

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata | None,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run sparse MLA attention and write the result into ``output``.

        Args:
            layer: Attention layer; its ``_k_scale`` is used for the cache
                write.
            q: Query tensor; the last dim is split into
                ``qk_nope_head_dim`` and ``qk_rope_head_dim`` parts.
            k_c_normed: Latent KV for the new tokens, written to the cache.
            k_pe: Rope part of the key for the new tokens, written to the
                cache.
            kv_cache: KV cache tensor; the write is skipped when it is empty.
            attn_metadata: Batch metadata, or ``None`` for a dummy run.
            output: Pre-allocated output tensor (required).
            output_scale: Must be ``None``; fused output quantization is not
                supported.
            output_block_scale: Must be ``None``; see ``output_scale``.

        Returns:
            ``output``, with the first ``num_actual_tokens`` rows filled in
            (or entirely zero-filled for a dummy run).

        Raises:
            NotImplementedError: If an output scale is supplied.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for MLACommonImpl"
            )

        if attn_metadata is None:
            # Dummy run - no need to allocate buffers
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens
        num_kv_actual_toks = attn_metadata.num_kv_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_kv_actual_toks, ...]
        k_pe = k_pe[:num_kv_actual_toks, ...]
        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        # Pick the kernel path: BF16 cache, FP8 with all tokens as a single
        # batch, or FP8 with separate prefill and decode handling.
        use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
        if not use_fp8_cache:
            attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata)
        elif attn_metadata.fp8_use_mixed_batch:
            attn_out = self._forward_fp8_kv_mixed_batch(
                q, kv_cache, topk_indices, attn_metadata
            )
        else:
            attn_out = self._forward_fp8_kv_separate_prefill_decode(
                q, kv_cache, topk_indices, attn_metadata
            )

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output