"tests/models/quantization/test_bitsandbytes.py" did not exist on "1b15df2546e97c409668da92954d8802c48d13af"
flashmla_sparse.py 33.6 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, ClassVar
5
6
7
8
9

import numpy as np
import torch

from vllm import _custom_ops as ops
10
from vllm.config import VllmConfig, get_current_vllm_config
11
from vllm.config.cache import CacheDType
12
from vllm.logger import init_logger
13
14
15
from vllm.model_executor.layers.attention.mla_attention import (
    get_mla_dims,
)
16
from vllm.platforms import current_platform
17
from vllm.platforms.interface import DeviceCapability
18
from vllm.utils.platform_utils import num_compute_units
19
from vllm.utils.torch_utils import is_quantized_kv_cache
20
21
from vllm.v1.attention.backend import (
    AttentionBackend,
22
    AttentionCGSupport,
23
24
    AttentionLayer,
    AttentionMetadata,
25
26
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
27
    MultipleOf,
28
    SparseMLAAttentionImpl,
29
)
30
31
32
from vllm.v1.attention.backends.mla.sparse_utils import (
    triton_convert_req_index_to_global_index,
)
33
from vllm.v1.attention.backends.utils import (
34
35
36
37
    reshape_attn_output_for_spec_decode,
    reshape_query_for_spec_decode,
    split_decodes_and_prefills,
    split_prefill_chunks,
38
)
39
from vllm.v1.attention.ops.flashmla import (
40
41
    FlashMLASchedMeta,
    flash_mla_sparse_fwd,
42
43
44
    flash_mla_with_kvcache,
    get_mla_metadata,
)
45
from vllm.v1.kv_cache_interface import AttentionSpec
46
from vllm.v1.worker.workspace import current_workspace_manager
47
48
49
50
51

if TYPE_CHECKING:
    from vllm.model_executor.models.deepseek_v2 import Indexer

logger = init_logger(__name__)  # module-level logger for this backend
52

# For FP8 sparse attention we have two implementations:
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode; this
#    is done by treating all tokens as a single batch.
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
#    (upconverting the FP8 cache to BF16, then calling the prefill kernel) and use
#    the FP8 decode kernel for decode.
# Currently we use #1 when the number of heads per rank is low (e.g. under high
# tensor parallelism), since the BF16 prefill kernel requires padding the number
# of heads to 128 while the decode kernel does not. So when the per-rank head
# count is below MIN_HEADS_FOR_BF16_PREFILL we use the mixed batch mode (#1).
MIN_HEADS_FOR_BF16_PREFILL = 32  # compared against the per-rank query-head count

65
66
67
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format

In the "FP8 with scale" format, each token's KV cache is 656 bytes
(512 + 16 + 128), structured as:
-   **First 512 bytes:** The "quantized NoPE" part, containing 512
    `float8_e4m3` values.
-   **Next 16 bytes:** Scale factors, containing 4 `float32` values.
    The first `float32` is the scale for the first 128 `float8_e4m3` values,
    the second for the next 128, and so on.
-   **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
    part is not quantized, to preserve accuracy.
"""


class FlashMLASparseBackend(AttentionBackend):
    """Backend descriptor for the FlashMLA sparse MLA attention kernels.

    Advertises the supported dtypes, block/head sizes and compute
    capabilities, and wires up the metadata builder and implementation
    classes defined in this module.
    """

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
        "fp8_ds_mla",
        "fp8",  # alias for fp8_ds_mla
    ]

    # ---- identity / wiring -------------------------------------------------

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_SPARSE"

    @staticmethod
    def get_impl_cls() -> type["FlashMLASparseImpl"]:
        return FlashMLASparseImpl

    @staticmethod
    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
        return FlashMLASparseMetadataBuilder

    # ---- capabilities ------------------------------------------------------

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [64]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def is_mla(cls) -> bool:
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        return True

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        # Only SM 9.x and SM 10.x devices are accepted.
        return capability.major in (9, 10)

    # ---- cache layout ------------------------------------------------------

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # The fp8_ds_mla layout packs every token into a custom 656-byte
        # record (see FlashMLA readme.md); other dtypes store head_size
        # elements per token.
        per_token = 656 if cache_dtype_str == "fp8_ds_mla" else head_size
        return (num_blocks, block_size, per_token)


@dataclass
class FlashMLASparseMetadata(AttentionMetadata):
    """Per-step attention metadata for the FlashMLA sparse backend.

    ``fp8_extra_metadata`` is only populated for the ``fp8_ds_mla`` cache.
    When ``fp8_use_mixed_batch`` is set it holds an ``FP8KernelMetadata``
    describing the whole batch. Otherwise it holds an
    ``FP8SeparatePrefillDecode`` with separate decode/prefill metadata.
    """

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    req_id_per_token: torch.Tensor
    block_size: int = 64
    topk_tokens: int = 2048

    @dataclass
    class FP8SeparatePrefillDecode:
        """FP8 metadata when decodes and prefills run through separate kernels."""

        @dataclass
        class Prefill:
            """Prefill-side metadata (prefill requests follow decodes in the batch)."""

            @dataclass
            class Chunk:
                """Metadata for a chunk of prefill requests.

                Prefill requests may be chunked to fit within the fixed workspace size.
                """

                seq_lens: torch.Tensor
                tokens_slice: slice
                block_table: torch.Tensor
                req_start_idx: int
                workspace_starts: torch.Tensor
                chunk_tot_seqlen: int

            # [num_prefill_reqs] sequence lengths (context + query).
            seq_lens: torch.Tensor

            # [num_actual_tokens] owning prefill request index per token
            # (0, 1, 2, ...), or -1 for decode tokens.
            request_ids: torch.Tensor

            # [num_prefill_reqs] workspace start offset of each prefill
            # request. Rebased in place per chunk so offsets are 0-indexed
            # inside their chunk; consumed when converting logical indices
            # to physical (workspace) indices.
            workspace_starts: torch.Tensor

            chunks: list[Chunk]

        @dataclass
        class Decode:
            kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
            decode_query_len: int  # needed for reshape in spec decode

        num_prefills: int = 0
        num_decodes: int = 0
        num_prefill_tokens: int = 0
        num_decode_tokens: int = 0

        decode: Decode | None = None
        prefill: Prefill | None = None

    @dataclass
    class FP8KernelMetadata:
        """Inputs for one call of the FP8 sparse decode kernel."""

        scheduler_metadata: FlashMLASchedMeta
        dummy_block_table: torch.Tensor
        cache_lens: torch.Tensor

    fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
    fp8_use_mixed_batch: bool = False


def get_prefill_workspace_size(max_model_len: int, *, multiplier: int = 5) -> int:
    """Return the number of token rows reserved for the BF16 prefill workspace.

    Args:
        max_model_len: Maximum model context length.
        multiplier: How many max-length sequences the workspace can hold.
            Defaults to 5, the historical hand-tuned value.

    Returns:
        ``max_model_len * multiplier`` rows.
    """
    # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
    # May be tuned later (now exposed as ``multiplier``).
    # Memory usage: multiplier * max_model_len * 576 * 2 bytes
    #   Example: DeepSeek-V3.2 with max_model_len=163840 ->
    #            5 * 163840 * 576 * 2 = ~900 MB
    # This fits nicely below the typical MoE workspace size of >2GB so this is "free"
    return max_model_len * multiplier


220
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
    """Builds ``FlashMLASparseMetadata`` for each scheduler step.

    Fixed-size device buffers (top-k lengths, max-model-len lengths, a dummy
    block table, scheduler scratch and ``req_id_per_token``) are allocated
    once in ``__init__`` and sliced per batch.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device

        # Treat requests with query length <= 1 as decodes to match the
        # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        sm_count = num_compute_units(device.index)

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        # FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
        self.fp8_decode_padded_heads = (
            FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
        )

        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
        self.topk_tokens_tensor = torch.full(
            (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
        )
        # Shape: [max_num_seqs], all elements = max_model_len
        self.max_model_len_tensor = torch.full(
            (max_num_seqs,),
            self.model_config.max_model_len,
            device=device,
            dtype=torch.int32,
        )
        # this is ignored by `flash_mla_with_kvcache` if indices not None
        self.dummy_block_table = torch.empty(
            (max_num_seqs, 1), dtype=torch.int32, device=self.device
        )

        # Equation taken from FlashMLA/csrc/api/sparse_decode.h
        # For sparse FP8 decode, the formula depends on architecture:
        # - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
        # - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
        # - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
        # For max buffer size, use s_q = 1 (the case that produces largest output)
        # Use padded head count since that's what will be passed to the kernel
        h_q = self.fp8_decode_padded_heads
        if current_platform.is_device_capability_family(100):
            # SM100 head64 or head64x2 uses full SM count
            max_num_sm_parts = sm_count
        else:
            # SM90 uses h_q/64 divisor
            max_num_sm_parts = sm_count // max(1, h_q // 64)
        self.tile_scheduler_metadata_buffer = torch.empty(
            # TileSchedulerMetaDataSize = 8
            # see: FlashMLA/csrc/params.h
            (max_num_sm_parts, 8),
            dtype=torch.int32,
            device=device,
        )
        # Sized for per-request batching (num_decodes + 1)
        self.num_splits_buffer = torch.empty(
            (max_num_seqs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )

    def _build_fp8_mixed_decode_prefill(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8KernelMetadata":
        """Build FP8 metadata treating all tokens as one mixed batch.

        Every token (decode or prefill) is fed to the FP8 decode kernel as a
        single batch. This avoids the BF16 prefill kernel, which has head
        padding overhead when num_heads is small (high TP case).
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        # Use padded head count since that's what the kernel will see
        padded_heads = self.fp8_decode_padded_heads

        # Build metadata for all tokens as a single batch
        scheduler_metadata, _ = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
            num_q_tokens_per_head_k=num_tokens * padded_heads,
            topk=self.topk_tokens,
            num_heads_q=padded_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        return FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=scheduler_metadata,
            cache_lens=self.max_model_len_tensor[:1],
            dummy_block_table=self.dummy_block_table[:1],
        )

    def _build_fp8_separate_prefill_decode(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
        """Build FP8 metadata for running decodes and prefills separately.

        Decodes (which come first in the batch) get FP8 decode-kernel
        scheduler metadata. Prefills are split into chunks whose total
        sequence length fits the BF16 prefill workspace. Each chunk gets
        chunk-local workspace offsets.
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold or 1,
                require_uniform=True,
            )
        )

        FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
        fp8_metadata = FP8Meta(
            num_decodes=num_decodes,
            num_prefills=num_prefills,
            num_decode_tokens=num_decode_tokens,
            num_prefill_tokens=num_prefill_tokens,
        )

        # Decode requests come first in the batch, prefill requests follow.
        # For pure decode batches no prefill metadata is built.
        if num_prefills > 0:
            # Upper bound is exact for prefill rows (the `[num_decodes:]`
            # slice below), so no D2H sync is needed.
            seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
            assert seq_lens_cpu is not None
            seq_lens = common_attn_metadata.seq_lens
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

            # Prefill sequence lengths (context + query, not just query)
            prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
            prefill_seq_lens = seq_lens[num_decodes:]

            # Build prefill_request_id: -1 for decode tokens, request index
            # (0, 1, 2, ...) for prefill tokens. This enables a single
            # convert_logical_index_to_physical_index call for all tokens.
            # Filled on the host and copied once, instead of launching one
            # device fill per prefill request.
            prefill_request_id_cpu = torch.empty(
                num_tokens, dtype=torch.int32, pin_memory=True
            )
            prefill_request_id_cpu.fill_(-1)
            query_bounds = query_start_loc_cpu.tolist()
            for req_idx in range(num_prefills):
                global_req_idx = num_decodes + req_idx
                req_query_start = query_bounds[global_req_idx]
                req_query_end = query_bounds[global_req_idx + 1]
                prefill_request_id_cpu[req_query_start:req_query_end] = req_idx
            prefill_request_id = prefill_request_id_cpu.to(
                self.device, non_blocking=True
            )

            # Global workspace offsets; rebased per chunk in the loop below.
            prefill_workspace_starts_cpu = torch.zeros(
                num_prefills, dtype=torch.int32, pin_memory=True
            )
            prefill_workspace_starts_cpu[1:] = torch.cumsum(
                prefill_seq_lens_cpu[:-1], dim=0
            )
            # populated by non-blocking copy after prefill_workspace_starts_cpu is
            # updated by each chunk
            prefill_workspace_starts = torch.empty(
                num_prefills, dtype=torch.int32, device=self.device
            )

            # Chunk prefill requests to fit within workspace size
            max_prefill_buffer_size = get_prefill_workspace_size(
                self.vllm_config.model_config.max_model_len
            )
            chunk_bounds = split_prefill_chunks(
                prefill_seq_lens_cpu, max_prefill_buffer_size
            )

            prefill_chunks = []
            for chunk_start, chunk_end in chunk_bounds:
                # Adjust workspace_starts in-place per chunk to be
                # 0-indexed within each chunk
                # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
                #   Initial: workspace_starts=[0,10,25,45]
                #   After:   workspace_starts=[0,10,0,20]
                #           (chunk 0 starts at 0, chunk 1 starts at 0)
                offset = prefill_workspace_starts_cpu[chunk_start].item()
                prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset

                chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
                # .item() so Chunk.chunk_tot_seqlen is a real int, as declared
                chunk_tot_seqlen = (
                    prefill_seq_lens_cpu[chunk_start:chunk_end].sum().item()
                )
                token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
                token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
                tokens_slice = slice(token_start, token_end)

                # Create chunk view of gpu tensor
                chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
                chunk_block_table = common_attn_metadata.block_table_tensor[
                    num_decodes + chunk_start : num_decodes + chunk_end
                ]

                prefill_chunks.append(
                    FP8Meta.Prefill.Chunk(
                        seq_lens=chunk_seq_lens,
                        tokens_slice=tokens_slice,
                        block_table=chunk_block_table,
                        req_start_idx=chunk_start,
                        workspace_starts=chunk_workspace_starts,
                        chunk_tot_seqlen=chunk_tot_seqlen,
                    )
                )

            # Copy only after every chunk has rebased its offsets.
            prefill_workspace_starts.copy_(
                prefill_workspace_starts_cpu, non_blocking=True
            )

            fp8_metadata.prefill = FP8Meta.Prefill(
                seq_lens=prefill_seq_lens,
                request_ids=prefill_request_id,
                workspace_starts=prefill_workspace_starts,
                chunks=prefill_chunks,
            )

        if num_decodes > 0:
            # Compute decode_query_len for spec decode (uniform due to require_uniform)
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()

            # Use padded head count since that's what the kernel will see
            padded_heads = self.fp8_decode_padded_heads
            scheduler_metadata, _ = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor[:num_decodes],
                num_q_tokens_per_head_k=decode_query_len * padded_heads,
                topk=self.topk_tokens,
                num_heads_q=padded_heads,
                num_heads_k=1,
                is_fp8_kvcache=True,
            )

            kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=scheduler_metadata,
                dummy_block_table=self.dummy_block_table[:num_decodes],
                cache_lens=self.max_model_len_tensor[:num_decodes],
            )
            fp8_metadata.decode = FP8Meta.Decode(
                kernel_metadata=kernel_meta,
                decode_query_len=decode_query_len,
            )

        return fp8_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashMLASparseMetadata:
        """Build the sparse-attention metadata for one batch.

        Fills ``req_id_per_token`` (the request index of every query token)
        into the persistent device buffer. For the ``fp8_ds_mla`` cache it
        also builds FP8 kernel metadata, for either the mixed-batch path
        (per-rank heads < MIN_HEADS_FOR_BF16_PREFILL) or the separate
        prefill/decode path.
        """
        cm = common_attn_metadata
        num_tokens = cm.num_actual_tokens
        starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata: (
            FlashMLASparseMetadata.FP8SeparatePrefillDecode
            | FlashMLASparseMetadata.FP8KernelMetadata
            | None
        ) = None
        fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
        if self.use_fp8_kv_cache:
            if fp8_use_mixed_batch:
                fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
            else:
                fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)

        metadata = FlashMLASparseMetadata(
            num_reqs=cm.num_reqs,
            max_query_len=cm.max_query_len,
            max_seq_len=cm.max_seq_len,
            num_actual_tokens=cm.num_actual_tokens,
            query_start_loc=cm.query_start_loc,
            slot_mapping=cm.slot_mapping,
            block_table=cm.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
            fp8_use_mixed_batch=fp8_use_mixed_batch,
        )

        return metadata


535
class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
Lucas Wilkinson's avatar
Lucas Wilkinson committed
536
537
538
539
540
541
    @staticmethod
    def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
        """Return the query-head count the FP8 sparse decode kernel runs with.

        The kernel only supports h_q of 64 or 128: head counts up to 64 are
        padded to 64, anything larger maps to 128.
        """
        if num_heads > 64:
            return 128
        return 64

542
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: "Indexer | None" = None,
        **mla_args,
    ) -> None:
        """Set up the sparse MLA implementation and reserve its workspaces.

        ``mla_args`` must contain ``kv_lora_rank``. ``indexer`` is required:
        its ``topk_indices_buffer`` is what this layer stores.
        ``topk_indice_buffer`` is not read here.

        A BF16 query-concat buffer of shape
        ``(max_num_batched_tokens, num_heads, head_size)`` is always
        reserved. For the ``fp8_ds_mla`` cache a BF16 prefill workspace is
        reserved alongside it. Any other quantized KV-cache dtype is rejected.
        """
        # Core attention parameters.
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_lora_rank: int = mla_args["kv_lora_rank"]
        self.softmax_scale = scale

        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer

        # Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
        on_blackwell = current_platform.is_device_capability_family(100)
        self.prefill_padding = 128 if on_blackwell else 64
        self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

        vllm_config = get_current_vllm_config()
        max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        q_concat_spec = ((max_tokens, num_heads, head_size), torch.bfloat16)

        uses_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
        if is_quantized_kv_cache(kv_cache_dtype):
            assert uses_fp8_ds_mla, (
                "FlashMLA Sparse Attention backend fp8 only supports "
                "fp8_ds_mla kv-cache dtype"
            )

        if not uses_fp8_ds_mla:
            # Non-FP8 caches only need the query-concat buffer.
            (self.q_concat_buffer,) = current_workspace_manager().get_simultaneous(
                q_concat_spec,
            )
            return

        # FP8 cache: also reserve the BF16 prefill workspace up front.
        assert vllm_config is not None and vllm_config.model_config is not None
        self.prefill_workspace_shape = (
            get_prefill_workspace_size(vllm_config.model_config.max_model_len),
            head_size,
        )
        self.q_concat_buffer, self.prefill_bf16_workspace = (
            current_workspace_manager().get_simultaneous(
                q_concat_spec,
                (self.prefill_workspace_shape, torch.bfloat16),
            )
        )
600

601
    def _forward_bf16_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Run sparse attention against a BF16 KV cache.

        The per-request ``topk_indices`` are first translated to global cache
        slots via ``req_id_per_token`` and the block table. The BF16 sparse
        kernel is then invoked on the translated indices.
        """
        num_topk = topk_indices.shape[1]
        global_slots = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=num_topk,
        )
        return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, global_slots)

    def _forward_fp8_kv_separate_prefill_decode(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """FP8-cache forward running decodes and prefills with separate kernels.

        Decode tokens (the first ``num_decode_tokens`` rows) go through the FP8
        sparse decode kernel. Prefill tokens are processed chunk by chunk:
        the chunk's KV is gathered from the FP8 cache and upconverted into the
        BF16 prefill workspace, then the BF16 sparse kernel runs on it.

        Args:
            q: Query tensor, one row per token along dim 0.
            kv_c_and_k_pe_cache: The ``fp8_ds_mla`` KV cache.
            topk_indices: Per-request top-k indices, one row per token.
            attn_metadata: Metadata whose ``fp8_extra_metadata`` is an
                ``FP8SeparatePrefillDecode``.

        Returns:
            For pure-decode batches, the reshaped decode-kernel output.
            Otherwise a tensor of shape
            ``(num_actual_tokens, num_heads, kv_lora_rank)``.
        """
        fp8_metadata = attn_metadata.fp8_extra_metadata
        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
        num_decodes = fp8_metadata.num_decodes
        num_decode_tokens = fp8_metadata.num_decode_tokens
        num_prefill_tokens = fp8_metadata.num_prefill_tokens

        prefill_request_ids = None
        prefill_workspace_starts = None
        has_prefill_workspace = False
        if fp8_metadata.prefill is not None:
            prefill_request_ids = fp8_metadata.prefill.request_ids
            prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
            has_prefill_workspace = True

        # Convert per-request indices to global cache slots (decode tokens) or
        # to workspace offsets (prefill tokens, whose KV is upconverted to BF16
        # into the workspace). prefill_workspace_starts has been adjusted
        # in-place per chunk so prefill indices automatically come out
        # chunk-local.
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
            HAS_PREFILL_WORKSPACE=has_prefill_workspace,
            prefill_workspace_request_ids=prefill_request_ids,
            prefill_workspace_starts=prefill_workspace_starts,
        )

        def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
            # Reshape q: (num_decode_tokens, num_heads, head_dim)
            #         -> (num_decodes, seq_len, num_heads, head_dim)
            q = reshape_query_for_spec_decode(q, num_decodes)
            seq_len = q.shape[1]
            # Reshape topk_indices: (num_decode_tokens, topk)
            #                    -> (num_decodes, seq_len, topk)
            topk_indices = topk_indices.view(num_decodes, seq_len, -1)
            assert fp8_metadata.decode is not None
            attn_out, _ = self._fp8_flash_mla_kernel(
                q=q,
                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                topk_indices=topk_indices,
                kernel_metadata=fp8_metadata.decode.kernel_metadata,
            )
            # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
            #              -> (num_decode_tokens, num_heads, head_dim_v)
            return reshape_attn_output_for_spec_decode(attn_out)

        # Pure decode: direct call without allocation
        if num_decode_tokens > 0 and num_prefill_tokens == 0:
            assert fp8_metadata.decode is not None
            return _fp8_decode(q, topk_indices)

        # Mixed, pure-prefill (or empty) batch: allocate the output tensor.
        attn_out = q.new_empty(
            (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )

        if num_decode_tokens > 0:
            attn_out[:num_decode_tokens] = _fp8_decode(
                q[:num_decode_tokens], topk_indices[:num_decode_tokens]
            )

        # Guarded so an empty batch (no prefill metadata built) does not trip
        # the assertion below.
        if num_prefill_tokens > 0:
            assert fp8_metadata.prefill is not None
            for chunk in fp8_metadata.prefill.chunks:
                # Upconvert this chunk's FP8 KV into the front of the BF16
                # workspace; topk indices were made chunk-local above.
                chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
                ops.cp_gather_and_upconvert_fp8_kv_cache(
                    kv_c_and_k_pe_cache,
                    chunk_workspace,
                    chunk.block_table,
                    chunk.seq_lens,
                    chunk.workspace_starts,
                    len(chunk.block_table),
                )

                chunk_q = q[chunk.tokens_slice]
                chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]

                attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
                    chunk_q,
                    chunk_workspace,
                    chunk_topk_indices_workspace,
                )

        return attn_out

    def _forward_fp8_kv_mixed_batch(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Run FP8 sparse attention over every token as one single batch.

        Prefill and decode tokens are not separated: all of them go through
        the FP8 decode kernel in a single call. This avoids the BF16 prefill
        kernel, whose head padding is costly when ``num_heads`` is small.
        Selected by ``forward_mqa`` when ``fp8_use_mixed_batch`` is set on
        the metadata.
        """
        num_topk = topk_indices.shape[1]
        # Translate per-request top-k indices into global cache slots (decode
        # tokens) or workspace offsets (prefill tokens).
        global_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=num_topk,
        )

        fp8_kernel_metadata = attn_metadata.fp8_extra_metadata
        assert fp8_kernel_metadata is not None
        assert isinstance(
            fp8_kernel_metadata, FlashMLASparseMetadata.FP8KernelMetadata
        )

        # The kernel wants a leading batch dim; treat all T tokens as a single
        # sequence: (T, H, D) -> (1, T, H, D) and (T, topk) -> (1, T, topk).
        batched_out, _lse = self._fp8_flash_mla_kernel(
            q=q[None],
            kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
            topk_indices=global_indices[None],
            kernel_metadata=fp8_kernel_metadata,
        )

        # Drop the batch dim again: (1, T, H, D_v) -> (T, H, D_v).
        return batched_out[0]

    def _fp8_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the FlashMLA FP8 sparse kernel (``flash_mla_with_kvcache``).

        Args:
            q: Query of shape ``(batch, seq_len, num_heads, head_dim)``.
                If ``num_heads`` is below ``self.fp8_decode_padded_heads``,
                it is zero-padded up to that count before the call.
            kv_c_and_k_pe_cache: FP8 KV cache. It is reinterpreted as
                ``uint8`` and given an extra size-1 dim at position -2
                before being handed to the kernel.
            topk_indices: Sparse KV indices forwarded as ``indices``.
            kernel_metadata: Precomputed FP8 metadata supplying the dummy
                block table, cache lengths and tile-scheduler metadata.

        Returns:
            ``(out, lse)`` from the kernel. ``out`` is sliced back to the
            caller's head count when padding was applied. ``lse`` is returned
            as produced by the kernel (not sliced).
        """
        # q shape: (batch, seq_len, num_heads, head_dim)
        actual_num_heads = q.size(2)
        padded_num_heads = self.fp8_decode_padded_heads

        # Pad query if needed (kernel only supports h_q = 64 or 128).
        # Zero-filled so the extra heads hold defined values; their outputs
        # are discarded below.
        if actual_num_heads < padded_num_heads:
            logger.warning_once(
                f"Padding num_heads from {actual_num_heads} to "
                f"{padded_num_heads} for FP8 sparse decode kernel"
            )
            q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
            q_padded[:, :, :actual_num_heads, :] = q
            q = q_padded

        out, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
            block_table=kernel_metadata.dummy_block_table,
            # NOTE(review): value head dim is hard-coded; presumably this must
            # equal self.kv_lora_rank, since callers size their output buffers
            # with it. Confirm before supporting other models.
            head_dim_v=512,
            cache_seqlens=kernel_metadata.cache_lens,
            tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
            is_fp8_kvcache=True,
            indices=topk_indices,
            softmax_scale=self.softmax_scale,
        )

        # Slice output back to actual head count if we padded
        if actual_num_heads < padded_num_heads:
            out = out[:, :, :actual_num_heads, :]

        return out, lse

    def _bf16_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Run the FlashMLA sparse prefill kernel (``flash_mla_sparse_fwd``).

        Args:
            q: Query, indexed here as ``(num_tokens, num_heads, head_dim)``.
                It is padded up to ``self.prefill_padding`` heads when
                ``self.num_heads`` is not a multiple of it.
            kv_c_and_k_pe_cache: KV tensor (the FP8 prefill path passes the
                upconverted per-chunk workspace). It is viewed as
                ``(-1, 1, last_dim)`` before the call.
            topk_indices: Per-token sparse indices, viewed as
                ``(num_tokens, 1, -1)`` for the kernel.

        Returns:
            Element 0 of the kernel's result, restricted to the first
            ``self.num_heads`` heads.
        """
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        # NOTE(Chen): kernel requires num_local_head to be a multiple of
        # 64 on hopper and 128 on blackwell
        if self.num_heads % self.prefill_padding != 0:
            # Padding is only supported when num_heads divides prefill_padding.
            assert self.prefill_padding % self.num_heads == 0
            logger.warning_once(
                f"Padding num_heads from {self.num_heads} to "
                f"{self.prefill_padding} for BF16 sparse prefill kernel"
            )
            # Extra heads are left uninitialized; their outputs are sliced off
            # below. NOTE(review): relies on the kernel computing each head
            # independently -- confirm.
            q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
            q_padded[:, : self.num_heads, :] = q
            q = q_padded

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = flash_mla_sparse_fwd(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
        )[0]
        # Drop any padded heads (a no-op slice when no padding was applied).
        output = output[:, : self.num_heads, :]
        return output

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run sparse MLA attention for one layer with the FlashMLA kernels.

        Dispatch:
          * ``self.kv_cache_dtype != "fp8_ds_mla"`` -> ``_forward_bf16_kv``
          * FP8 cache and ``attn_metadata.fp8_use_mixed_batch`` ->
            ``_forward_fp8_kv_mixed_batch``
          * FP8 cache otherwise -> ``_forward_fp8_kv_separate_prefill_decode``

        Args:
            q: The query tensor, or a ``(ql_nope, q_pe)`` pair. A pair is
                concatenated into (a slice of) ``self.q_concat_buffer``.
            kv_c_and_k_pe_cache: KV cache tensor for this layer.
            attn_metadata: Sparse FlashMLA metadata for the current batch.
            layer: Not used by this implementation.

        Returns:
            ``(attn_out, None)``. The second element is always ``None`` here.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        # Concatenate q if it's a tuple (ql_nope, q_pe)
        if isinstance(q, tuple):
            ql_nope, q_pe = q
            q = self.q_concat_buffer[: ql_nope.shape[0]]
            ops.concat_mla_q(ql_nope, q_pe, q)

        num_actual_toks = q.shape[0]

        # Get topk indices
        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"

        if not use_fp8_cache:
            attn_out = self._forward_bf16_kv(
                q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
            )
        elif attn_metadata.fp8_use_mixed_batch:
            attn_out = self._forward_fp8_kv_mixed_batch(
                q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
            )
        else:
            attn_out = self._forward_fp8_kv_separate_prefill_decode(
                q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
            )

        return attn_out, None