# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
# MLA Common Components

This file implements common components for MLA implementations.

First we define:

Sq      as Q sequence length
Skv     as KV sequence length

MLA has two possible ways of computing: a data-movement friendly approach and a
compute friendly approach. We generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "large", near 1) and the
data-movement friendly approach for "decode" (i.e. the ratio Sq / Skv is
"small", e.g. 1 / Skv for a single-token decode).

NOTE: what we deem small and large is currently determined by whether the
scheduler labels a request as prefill or decode, but this is something we
should probably tune.

Main reference: the DeepSeek-V2 paper and the FlashInfer implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.
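
As a rough illustration of the first point, the sketch below stores one latent
entry per token, `[kv_c | k_pe]` of width `Lkv + R`, in a paged cache shaped
like `MLACommonBackend.get_kv_cache_shape` returns, i.e.
`(num_blocks, block_size, head_size)`. It is a plain-PyTorch sketch, not vLLM's
actual cache-write path: `write_latent_cache` is a hypothetical helper, and the
flat `slot = block_idx * block_size + offset` layout of `slot_mapping` is an
assumption.

```python
import torch


def write_latent_cache(
    kv_cache: torch.Tensor,      # [num_blocks, block_size, Lkv + R]
    kv_c: torch.Tensor,          # [T, Lkv] latent/compressed KV
    k_pe: torch.Tensor,          # [T, R] decoupled rope key
    slot_mapping: torch.Tensor,  # [T] flat slot index per token
) -> None:
    # Hypothetical helper: one latent vector per token, shared by all heads.
    num_blocks, block_size, head_size = kv_cache.shape
    assert head_size == kv_c.shape[-1] + k_pe.shape[-1]
    # Assumed layout: slot = block_idx * block_size + offset_in_block.
    flat = kv_cache.view(num_blocks * block_size, head_size)
    flat[slot_mapping] = torch.cat([kv_c, k_pe], dim=-1)


Lkv, R = 512, 64                      # DSV3 extents from the table below
cache = torch.zeros(4, 16, Lkv + R)   # 4 blocks of 16 tokens, head_size 576
kv_c, k_pe = torch.randn(3, Lkv), torch.randn(3, R)
slots = torch.tensor([0, 17, 33])     # (block 0, off 0), (1, 1), (2, 1)
write_latent_cache(cache, kv_c, k_pe, slots)
assert torch.equal(cache[1, 1, :Lkv], kv_c[1])
```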

Below is an example of both paths, assuming batch size = 1.

## More Extent Definitions:

C           Context length, `Skv - Sq`
H           hidden size
N           number of attention heads
Lq          latent dimension for Q              1536 in DSV3
Lkv         latent dimension for K/V            512 in DSV3
P           nope dimension, no rope.            128 in DSV3
R           rope dimension, goes through rope.  64 in DSV3
V           V head dim.                         128 in DSV3

## Vector/Matrix Definitions

h_t         hidden states (input to attention)  shape [Sq, H]
q_c         latent/compressed Q                 shape [Sq, Lq]
q_nope      uncompressed Q (no-rope)            shape [Sq, N, P]
q_pe        uncompressed Q (rope)               shape [Sq, N, R]
kv_c        latent/compressed KV                shape [Skv, Lkv]
k_pe        decoupled k position embeddings     shape [Skv, R]
new_kv_c    new kv_c from current iter          shape [Sq, Lkv]
new_k_pe    new k_pe from current iter          shape [Sq, R]
cache_kv_c  cached kv_c from previous iters     shape [C, Lkv]
cache_k_pe  cached k_pe from previous iters     shape [C, R]
W_DQ        project h_t to q_c                  shape [H, Lq]
W_UQ        project q_c to q_nope               shape [Lq, N * P]
W_QR        project q_c to q_pe                 shape [Lq, N * R]
W_DKV       project h_t to kv_c                 shape [H, Lkv]
W_UK        project kv_c to k_nope              shape [Lkv, N, P]
W_KR        project h_t to k_pe                 shape [H, R]
W_UV        project kv_c to v                   shape [Lkv, N, V]
W_O         project v to h_t                    shape [N * V, H]
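
Plugging these DSV3 extents into the two paths gives a feel for the trade-off.
The back-of-the-envelope arithmetic below assumes N = 128 heads (the figure used
in the workspace-size comment in `determine_chunked_prefill_workspace_size`) and
counts a multiply-add as 2 FLOPs. It compares how many K/V elements each context
token contributes and the attention FLOPs per (query token, context token) pair.
Note that `Lkv + R = 576` is the single head size returned by
`get_supported_head_sizes`.

```python
# DSV3 extents from the table above; N = 128 heads is an assumption.
Lkv, P, R, V, N = 512, 128, 64, 128, 128

# Elements stored per token in the latent KV cache: kv_c plus k_pe.
latent_entry = Lkv + R

# Elements per context token once K/V are up-projected for the compute
# friendly (MHA) path: per-head K of width P + R and per-head V of width V.
mha_kv_entry = N * (P + R) + N * V

# Attention FLOPs per (query token, context token) pair over all heads,
# QK^T plus PV, counting a multiply-add as 2 FLOPs.
mha_flops = 2 * N * (P + R) + 2 * N * V
mqa_flops = 2 * N * (Lkv + R) + 2 * N * Lkv

print(latent_entry, mha_kv_entry, mha_flops, mqa_flops)  # 576 40960 81920 278528
```

Under these assumptions the latent (MQA) form does 3.4x the attention FLOPs but
reads about 71x fewer K/V elements per context token, which is why it suits
decode rather than prefill.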


## Compute Friendly Approach (i.e. "_forward_prefill"):

q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(Sq, N, P)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([cache_kv_c, new_kv_c], dim=0)
k_pe     = torch.cat([cache_k_pe, new_k_pe], dim=0)
k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)

// MHA with QK headdim = P + R
//           V headdim = V
//      spda_o shape [Sq, N, V]
spda_o = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    v
)
return spda_o @ W_O
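
A runnable version of the pseudocode above, as a minimal sketch under
simplifying assumptions: tiny made-up extents and random weights, batch size 1,
RoPE replaced by an identity placeholder, and
`torch.nn.functional.scaled_dot_product_attention` standing in for the attention
kernel. The cached context is placed before the new tokens, so an explicit mask
lets query `i` (at absolute position `C + i`) attend to keys `0..C + i`. It only
illustrates shapes and data flow.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Tiny illustrative extents (assumptions, not the DSV3 values).
H, N, Lq, Lkv, P, R, V = 32, 4, 24, 16, 8, 4, 8
Sq, C = 5, 3                    # new tokens, cached context tokens
Skv = C + Sq

W_DQ, W_DKV, W_KR = torch.randn(H, Lq), torch.randn(H, Lkv), torch.randn(H, R)
W_UQ, W_QR = torch.randn(Lq, N * P), torch.randn(Lq, N * R)
W_UK, W_UV = torch.randn(Lkv, N, P), torch.randn(Lkv, N, V)
W_O = torch.randn(N * V, H)


def rope(x: torch.Tensor) -> torch.Tensor:
    return x  # placeholder: RoPE is omitted in this sketch


h_t = torch.randn(Sq, H)
cache_kv_c, cache_k_pe = torch.randn(C, Lkv), torch.randn(C, R)

q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = rope(q_c @ W_QR).view(Sq, N, R)
kv_c = torch.cat([cache_kv_c, h_t @ W_DKV], dim=0)       # [Skv, Lkv]
k_pe = torch.cat([cache_k_pe, rope(h_t @ W_KR)], dim=0)  # [Skv, R]
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)

# SDPA takes [heads, seq, dim]; k_pe is shared by all heads.
q = torch.cat([q_nope, q_pe], dim=-1).transpose(0, 1)
k = torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1).transpose(0, 1)
# Query i sits at absolute position C + i and may attend to keys 0..C + i.
mask = torch.arange(Skv)[None, :] <= (C + torch.arange(Sq))[:, None]
spda_o = F.scaled_dot_product_attention(q, k, v.transpose(0, 1), attn_mask=mask)
out = spda_o.transpose(0, 1).reshape(Sq, N * V) @ W_O
print(out.shape)  # torch.Size([5, 32])
```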

NOTE: in the actual code,
    `kv_b_proj` is [W_UK; W_UV] concatenated per head
    `q_b_proj` is [W_UQ; W_QR] concatenated per head
    `out_proj` is W_O
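
For example, splitting a fused `kv_b_proj` weight back into `W_UK` and `W_UV`
could look like the sketch below. It assumes the weight is stored in
`torch.nn.Linear` layout, `[out_features, in_features] = [N * (P + V), Lkv]`,
with each head's `P` key columns followed by its `V` value columns; the helper
`split_kv_b_proj` is hypothetical.

```python
import torch


def split_kv_b_proj(weight: torch.Tensor, N: int, P: int, V: int):
    # weight: [N * (P + V), Lkv] in nn.Linear layout (assumed).
    Lkv = weight.shape[1]
    w = weight.T.reshape(Lkv, N, P + V)   # per head: [W_UK | W_UV]
    W_UK, W_UV = w.split([P, V], dim=-1)  # [Lkv, N, P], [Lkv, N, V]
    return W_UK, W_UV


N, P, V, Lkv = 4, 8, 6, 16
weight = torch.randn(N * (P + V), Lkv)
W_UK, W_UV = split_kv_b_proj(weight, N, P, V)

# What the fused projection would output for three latent vectors.
kv_c = torch.randn(3, Lkv)
fused = (kv_c @ weight.T).view(3, N, P + V)
k_nope = torch.einsum("sl,lnp->snp", kv_c, W_UK)
v = torch.einsum("sl,lnv->snv", kv_c, W_UV)
assert torch.allclose(fused[..., :P], k_nope, atol=1e-4)
assert torch.allclose(fused[..., P:], v, atol=1e-4)
```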


## Data-Movement Friendly Approach (i.e. "_forward_decode"):

Runtime
q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(-1, N, P)
ql_nope  = einsum("snh,lnh->snl", q_nope, W_UK)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([cache_kv_c, new_kv_c], dim=0)
k_pe     = torch.cat([cache_k_pe, new_k_pe], dim=0)

// MQA with QK headdim = Lkv + R
//           V headdim = Lkv
//      spda_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P,
//       but more data-movement friendly since it is MQA rather than MHA
spda_o = scaled_dot_product_attention(
    torch.cat([ql_nope, q_pe], dim=-1),
    torch.cat([kv_c, k_pe], dim=-1),
    kv_c
)

o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O


## Chunked Prefill

For chunked prefill we want to use the compute friendly algorithm. We are
assuming a sufficiently large Sq / Skv ratio; in the future we may want to
switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is small.

However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)`

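For scale, the arithmetic below evaluates the configuration assumed in the
workspace-size comment in `determine_chunked_prefill_workspace_size` (fp16,
576-wide latent entries, 128 heads with a 192-wide QK head dim, 64k context
tokens). These are that comment's assumptions, not measurements.

```python
BYTES_FP16 = 2
TOKENS = 64 * 1024
MiB, GiB = 1024**2, 1024**3

# Latent workspace: one 576-wide (Lkv + R) entry per context token.
latent_bytes = BYTES_FP16 * 576 * TOKENS
# Up-projected keys for the same tokens: 128 heads, each P + R = 192 wide.
k_up_bytes = BYTES_FP16 * (192 * 128) * TOKENS

print(latent_bytes / MiB, k_up_bytes / GiB)  # 72.0 3.0
```

The up-projected keys alone come to roughly 43x the latent entries (and `v`
adds more), which is what processing the context in bounded chunks avoids
materializing all at once.
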
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
fixed workspace size.
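
How large each chunk can be is decided per batch in
`MLACommonMetadataBuilder.build`: the fixed workspace is split evenly across the
prefills that have context and, on CUDA, rounded down to a multiple of the page
size because `gather_and_maybe_dequant_cache` cannot handle chunk starts that
are not page aligned. The sketch below mirrors that logic with plain integers;
`max_context_chunk` and `chunk_bounds` are hypothetical helpers, `align_to_page`
stands in for `self.aot_schedule`, and `cdiv` / `round_down` are local
re-implementations rather than imports from `vllm.utils.math_utils`.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division for positive integers


def round_down(x: int, multiple: int) -> int:
    return (x // multiple) * multiple


def max_context_chunk(workspace_size: int, num_prefills_with_context: int,
                      page_size: int, align_to_page: bool = True) -> int:
    # Equal share of the workspace per prefill that has context.
    chunk = workspace_size // num_prefills_with_context
    if align_to_page:
        # Chunk starts must be page aligned for the cache gather kernel.
        chunk = round_down(chunk, page_size)
    assert chunk > 0
    return chunk


def chunk_bounds(context_len: int, mcc: int) -> list[tuple[int, int]]:
    # [start, end) ranges over one request's cached context.
    return [
        (i * mcc, min((i + 1) * mcc, context_len))
        for i in range(cdiv(context_len, mcc))
    ]


mcc = max_context_chunk(workspace_size=64 * 1024, num_prefills_with_context=3,
                        page_size=16)
print(mcc)                        # 21840
print(chunk_bounds(50_000, mcc))  # [(0, 21840), (21840, 43680), (43680, 50000)]
```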

The chunked prefill approach is as follows:

MCC        Max chunk of context to process per iter, computed dynamically,
           used to bound the memory usage

q_c        = h_t @ W_DQ
q_nope     = (q_c @ W_UQ).view(Sq, N, P)
q_pe       = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c   = h_t @ W_DKV
new_k_pe   = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v      = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)

// MHA between queries and new KV
//     with QK headdim = P + R
//           V headdim = V
//    curr_o   shape [Sq, N, V]
//    curr_lse shape [N, Sq], this is just the order FA returns it in
curr_o, curr_lse = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    new_v,
    causal=True,
    return_softmax_lse=True
)

// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
    chunk_start  = chunk_idx * MCC
    chunk_end    = min(chunk_start + MCC, C)
    Sc           = chunk_end - chunk_start
    cache_kv_c_chunk   = cache_kv_c[chunk_start:chunk_end]
    cache_k_pe_chunk   = cache_k_pe[chunk_start:chunk_end]
    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK.view(Lkv, N * P)).view(-1, N, P)
    cache_v_chunk      = (cache_kv_c_chunk @ W_UV.view(Lkv, N * V)).view(-1, N, V)

    chunk_o, chunk_lse = scaled_dot_product_attention(
        torch.cat([q_nope, q_pe], dim=-1),
        torch.cat([cache_k_nope_chunk,
                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
                   dim=-1),
        cache_v_chunk,
        causal=False,
        return_softmax_lse=True
    )

    curr_o, curr_lse = merge_attn_states(
        suffix_output=curr_o,
        suffix_lse=curr_lse,
        prefix_output=chunk_o,
        prefix_lse=chunk_lse,
    )

return curr_o @ W_O
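
The merge relies on every partial attention returning its log-sum-exp: two
softmax-normalized outputs over disjoint key sets can then be recombined
exactly, weighting each by `exp(lse_part - lse_total)`. The sketch below checks
that the loop above (causal attention over the new tokens, then merging in
non-causal attention over MCC-sized chunks of context) matches a single pass
over all keys. It is a single-head, projection-free sketch; `attn_with_lse` and
`merge_states` are hypothetical plain-PyTorch helpers, not the signature of
vLLM's `merge_attn_states`.

```python
import torch

torch.manual_seed(0)


def attn_with_lse(q, k, v, mask=None):
    # Single-head attention; returns output [Sq, Dv] and log-sum-exp [Sq].
    s = (q @ k.T) / q.shape[-1] ** 0.5
    if mask is not None:
        s = s.masked_fill(~mask, float("-inf"))
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)


def merge_states(o_a, lse_a, o_b, lse_b):
    # Recombine two partial results computed over disjoint key sets.
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * o_a + w_b * o_b, lse


Sq, C, MCC, D, Dv = 4, 10, 3, 8, 8   # tiny illustrative sizes
q = torch.randn(Sq, D)
cache_k, cache_v = torch.randn(C, D), torch.randn(C, Dv)
new_k, new_v = torch.randn(Sq, D), torch.randn(Sq, Dv)

# Causal attention between the queries and the new tokens.
causal = torch.ones(Sq, Sq, dtype=torch.bool).tril()
curr_o, curr_lse = attn_with_lse(q, new_k, new_v, causal)

# Non-causal attention to each MCC-sized chunk of cached context, merged in.
for start in range(0, C, MCC):
    end = min(start + MCC, C)
    chunk_o, chunk_lse = attn_with_lse(q, cache_k[start:end], cache_v[start:end])
    curr_o, curr_lse = merge_states(curr_o, curr_lse, chunk_o, chunk_lse)

# Reference: one pass over [cache; new], with query i at position C + i.
full_mask = torch.arange(C + Sq)[None, :] <= (C + torch.arange(Sq))[:, None]
ref_o, _ = attn_with_lse(
    q, torch.cat([cache_k, new_k]), torch.cat([cache_v, new_v]), full_mask
)
print(torch.allclose(curr_o, ref_o, atol=1e-5))  # True
```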
"""

import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import ClassVar, Generic, TypeVar

import torch
from tqdm import tqdm

from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionLayer,
    MLAAttentionImpl,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    UnquantizedLinearMethod,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    get_dcp_local_seq_lens,
    get_per_layer_parameters,
    infer_global_hyperparameters,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec


class QueryLenSupport(Enum):
    """Defines the level of query length support for an attention backend's
    decode pipeline.

    - SINGLE_ONLY: Decode pipeline only supports single-token queries
                   (query_len=1)
    - UNIFORM: Decode pipeline supports uniform multi-token queries
               (all requests must have same query_len > 1)
    - VARLEN: Decode pipeline supports variable-length queries
              (mixed query lengths in same batch)
    """

    SINGLE_ONLY = "single_only"
    UNIFORM = "uniform"
    VARLEN = "varlen"


try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func

    is_vllm_fa = True
except ImportError:
    # For rocm use upstream flash attention
    if current_platform.is_rocm():
        from flash_attn import flash_attn_varlen_func
    is_vllm_fa = False

try:
    from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache  # noqa: F401

    flashinfer_available = True
except ImportError:
    BatchPrefillWithRaggedKVCacheWrapper = object

    flashinfer_available = False


def dynamic_per_batched_tensor_quant(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
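):
    """Quantize `x` to `dtype` with one dynamic scale for the whole tensor.

    Returns the saturated, quantized tensor and the dequantization scale
    (`amax / finfo(dtype).max`, the reciprocal of the quantization scale).
    """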
    DTYPE_MAX = torch.finfo(dtype).max
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
    scale = DTYPE_MAX / amax
    x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


logger = init_logger(__name__)

CUDNN_WORKSPACE_SIZE = 12800


class MLACommonBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"

    @staticmethod
    def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
        return MLACommonMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        # (num_blocks, num_layers, block_size, head_size)
        return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def is_mla(cls) -> bool:
        return True


@dataclass
class MLACommonPrefillMetadata:
    """Prefill Specific Metadata"""

    @dataclass
    class ChunkedContextMetadata:
        # New for MLA (compared to FlashAttention)
        # For handling chunked prefill
        cu_seq_lens: torch.Tensor
        starts: torch.Tensor
        seq_tot: list[int]
        max_seq_lens: list[int]
        seq_lens: torch.Tensor
        workspace: torch.Tensor
        token_to_seq: torch.Tensor
        chunk_total_token: list[int]

        # for mla DCP
        padded_local_chunk_seq_lens: list[list[int]] | None = None
        local_context_lens_allranks: list[list[int]] | None = None
        padded_local_cu_seq_lens: torch.Tensor | None = None
        cu_seq_lens_lst: list[list[int]] | None = None
        chunk_size: int | None = None

    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    max_query_len: int
    chunked_context: ChunkedContextMetadata | None = None
    query_seq_lens: torch.Tensor | None = None


@dataclass
class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
    prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
    prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = field(
        default_factory=list
    )


@dataclass
class CudnnPrefillMetadata(MLACommonPrefillMetadata):
    class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
        seq_lens: torch.Tensor

    cudnn_workspace: torch.Tensor | None = None


@dataclass
class MLACommonDecodeMetadata:
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    dcp_tot_seq_lens: torch.Tensor | None


D = TypeVar("D", bound=MLACommonDecodeMetadata)


@dataclass
class MLACommonMetadata(Generic[D]):
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # The dimension of the attention heads
    head_dim: int | None = None

    decode: D | None = None
    prefill: (
        MLACommonPrefillMetadata
        | FlashInferPrefillMetadata
        | CudnnPrefillMetadata
        | None
    ) = None

    def __post_init__(self):
        if self.head_dim is not None and not MLACommonBackend.supports_head_size(
            self.head_dim
        ):
            raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.")


M = TypeVar("M", bound=MLACommonMetadata)
A = TypeVar("A")


def use_flashinfer_prefill() -> bool:
    # For Blackwell, default to FlashInfer prefill if it's available since
    # it is faster than FA2.
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    return (
        not vllm_config.attention_config.disable_flashinfer_prefill
        and flashinfer_available
        and not vllm_config.attention_config.use_cudnn_prefill
        and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
        and current_platform.is_device_capability(100)
    )


def use_cudnn_prefill() -> bool:
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    return (
        flashinfer_available
        and vllm_config.attention_config.use_cudnn_prefill
        and current_platform.is_device_capability(100)
        and has_nvidia_artifactory()
    )


def use_trtllm_ragged_deepseek_prefill() -> bool:
    """Check if TRT-LLM ragged DeepSeek prefill should be used."""
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    return (
        flashinfer_available
        and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
        and current_platform.is_device_capability(100)
    )


class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    # Defines the level of query length support for this backend.
    # - SINGLE_ONLY: Only single-token queries (no spec decode support)
    # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
    # - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
    # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
    # speculative decoding is enabled.
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY

    # The threshold for reordering the batch into decode and prefill requests.
    # If > 1, the batch will be reordered such that requests with
    # query length <= threshold are classified as decode requests.
    # Use `query_len_support` (above) to set this automatically
    # when speculative decoding is enabled.
    reorder_batch_threshold: int = 1

    @staticmethod
    def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
        scheduler_config = vllm_config.scheduler_config
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config

        chunked_prefill_workspace_size = min(
            # Try for 8 full-length requests or at least 4 pages per request
            max(
                8 * model_config.max_model_len,
                4 * scheduler_config.max_num_seqs * cache_config.block_size,
            ),
            # For long-context models try not to over-allocate limiting
            # kv-cache space, limiting it to 64k tokens,
            # which would result in the workspace being:
            #   2*(576)*(64*1024) = 72mb
            # (assuming 576 MLA head dim, and fp16)
            # which would result in up-projected context being
            #   2*(192*128)*(64*1024) = 3gb
            # (assuming 192 QK head dim, 128 heads, and fp16)
            64 * 1024,
        )

        # Enforce that we have enough for at least 1 page per request
        chunked_prefill_workspace_size = max(
            chunked_prefill_workspace_size,
            scheduler_config.max_num_seqs * cache_config.block_size,
        )

        return chunked_prefill_workspace_size

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
        metadata_cls: type[M] | None = None,
        supports_dcp_with_varlen: bool = False,
    ):
        self.metadata_cls = (
            metadata_cls if metadata_cls is not None else MLACommonMetadata
        )
        self.kv_cache_spec = kv_cache_spec
        scheduler_config = vllm_config.scheduler_config
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.device = device

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.aot_schedule = current_platform.is_cuda()
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
        self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size

        # Don't try to access the runner on AMD
        if self.aot_schedule:
            self.page_size = self.kv_cache_spec.block_size

        self.chunked_prefill_workspace_size = (
            self.determine_chunked_prefill_workspace_size(vllm_config)
        )

        if self.dcp_world_size > 1:
            # Note(hc): The local kvcache is incomplete when DCP is triggered,
            # an additional kvcache allgather across the DCP group is therefore
            # required, so the workspace has to be enlarged by 1/DCP relative
            # to the original TP allocation.
            assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0
            self.chunked_prefill_workspace = torch.empty(
                (
                    self.chunked_prefill_workspace_size
                    + self.chunked_prefill_workspace_size // self.dcp_world_size,
                    self.model_config.get_head_size(),
                ),
                dtype=self.model_config.dtype,
                device=device,
            )
        else:
            self.chunked_prefill_workspace = torch.empty(
                (
                    self.chunked_prefill_workspace_size,
                    self.model_config.get_head_size(),
                ),
                dtype=self.model_config.dtype,
                device=device,
            )

        self._use_cudnn_prefill = use_cudnn_prefill()
        self._use_fi_prefill = use_flashinfer_prefill()
        self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
        self.prefill_metadata_cls = (
            FlashInferPrefillMetadata
            if self._use_fi_prefill
            else CudnnPrefillMetadata
            if self._use_cudnn_prefill
            else MLACommonPrefillMetadata
        )

        if self._use_fi_prefill:
            self._workspace_buffer = torch.empty(
                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=device,
            )

            self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
            self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []

            self._global_hyperparameters = infer_global_hyperparameters(
                get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
            )

        if self._use_trtllm_ragged_prefill:
            self._workspace_buffer = torch.empty(
                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=device,
            )

        if self._use_cudnn_prefill:
            self.cudnn_workspace = torch.empty(
                CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
                dtype=torch.int8,
                device=device,
            )

        supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
        self._init_reorder_batch_threshold(
            self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
        )

        # Validate consistency between query_len_support and reorder_batch_threshold
        if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
            assert self.reorder_batch_threshold == 1, (
                f"reorder_batch_threshold must be 1 when query_len_support is "
                f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
            )

    def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
        qo_indptr = prefill.query_start_loc

        has_context = False
        if prefill.chunked_context is not None:
            chunked_context = prefill.chunked_context
            has_context = True

        if self._fi_prefill_main is None:
            self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
                self._workspace_buffer, "NHD", backend="cutlass"
            )

        if has_context:
            num_chunks = chunked_context.cu_seq_lens.shape[0]
            # Allocate more prefill chunk wrappers if needed
            if len(self._fi_prefill_chunks) < num_chunks:
                for _ in range(len(self._fi_prefill_chunks), num_chunks):
                    self._fi_prefill_chunks.append(
                        BatchPrefillWithRaggedKVCacheWrapper(
                            self._workspace_buffer, "NHD", backend="cutlass"
                        )
                    )
            assert num_chunks <= len(self._fi_prefill_chunks)

        # In MLA, the non-latent num_qo_heads == num_kv_heads
        num_qo_heads = self.num_heads
        num_kv_heads = num_qo_heads

        # Sanity: Verify that num_kv_heads == 1 since it is latent space
        assert self.kv_cache_spec.num_kv_heads == 1

        # Get non-latent head_dim_qk and head_dim_vo
        head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim
        head_dim_vo = self.mla_dims.v_head_dim

        # For main run, qo_indptr == kv_indptr
        kv_indptr = qo_indptr.clone()

        # Prepare main prefill
        self._fi_prefill_main.plan(
            qo_indptr=qo_indptr,
            kv_indptr=kv_indptr,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim_qk,
            head_dim_vo=head_dim_vo,
            causal=True,  # This is main run
            sm_scale=self._global_hyperparameters.sm_scale,
            window_left=self._global_hyperparameters.window_left,
            logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
            q_data_type=self.model_config.dtype,
        )

        # Prepare context prefills
        if has_context:
            for i in range(num_chunks):
                kv_indptr_chunk = chunked_context.cu_seq_lens[i]

                self._fi_prefill_chunks[i].plan(
                    qo_indptr=qo_indptr,
                    kv_indptr=kv_indptr_chunk,
                    num_qo_heads=num_qo_heads,
                    num_kv_heads=num_kv_heads,
                    head_dim_qk=head_dim_qk,
                    head_dim_vo=head_dim_vo,
                    causal=False,  # This is context run
                    sm_scale=self._global_hyperparameters.sm_scale,
                    window_left=self._global_hyperparameters.window_left,
                    logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
                    q_data_type=self.model_config.dtype,
                )

        prefill.prefill_main = self._fi_prefill_main
        prefill.prefill_chunks = self._fi_prefill_chunks

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> MLACommonDecodeMetadata:
        return MLACommonDecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with MLA.
        """
        m = common_attn_metadata
        assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), (
            "MLA only supports decode-only full CUDAGraph capture. "
            "Make sure all cudagraph capture sizes <= max_num_seqs."
        )

        assert m.max_query_len <= self.reorder_batch_threshold  # decode only

        return self.build(0, m)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.device
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
        dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
            )
        )

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes  # prefill_start

            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            prefill_query_start_loc = (
                query_start_loc[reqs_start:] - query_start_loc[reqs_start]
            )

            chunked_context_metadata = None
            if max_context_len_cpu > 0:
                # NOTE: it is recommended that you read the `Chunked Prefill`
                # section in the comment at the top of the file before trying
                # to understand the following code.

                # Currently we allocate an equal amount of workspace to each
                # prefill in the batch. A more advanced algorithm could
                # allocate more workspace to prefills with longer context
                # lengths.
                max_context_chunk = (
                    self.chunked_prefill_workspace_size // num_prefills_with_context_cpu
                )

                if self.aot_schedule:
                    # Align max_context_chunk to page_size by rounding down:
                    # the `gather_and_maybe_dequant_cache` kernel currently
                    # cannot handle `context_chunk_starts` that are not aligned
                    # to page_size.
                    max_context_chunk = round_down(max_context_chunk, self.page_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)

                # If `max_context_chunk = 256`, `num_chunks = 3`, and
                # `num_prefills_with_context = 4`, create a tensor that looks
                # like
                #  [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
                # Note(simon): this is done on the CPU because downstream code
                # calls `tolist()` on the results.
                chunk_starts = (
                    torch.arange(num_chunks, dtype=torch.int32)
                    .unsqueeze(1)
                    .expand(-1, num_prefills)
                    * max_context_chunk
                )
                chunk_ends = torch.min(
                    context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk
                )
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

                cu_seq_lens_cpu = torch.zeros(
                    num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
                )
                torch.cumsum(
                    chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
                )
                chunk_total_token = cu_seq_lens_cpu[:, -1]

                max_token_num_over_chunk = chunk_total_token.max().item()
                token_to_seq_tensor_cpu = torch.zeros(
                    [num_chunks, max_token_num_over_chunk], dtype=torch.int32
                )
                range_idx = torch.arange(num_prefills, dtype=torch.int32)
                for i in range(num_chunks):
                    chunk_token_to_seq_tensor = torch.repeat_interleave(
                        range_idx, chunk_seq_lens[i]
                    )
                    chunk_len = chunk_token_to_seq_tensor.shape[0]
                    token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor

                if self.dcp_world_size > 1:
                    local_context_lens_allranks = get_dcp_local_seq_lens(
                        context_lens_cpu,
                        self.dcp_world_size,
                        None,
                        self.dcp_local_block_size,
                    )
                    # Note(qcs): The max local context lengths
                    # padded to `dcp_local_block_size`.
                    padded_local_context_lens_cpu = (
                        cdiv(
                            context_lens_cpu,
                            self.dcp_virtual_block_size,
                        )
                        * self.dcp_local_block_size
                    )
                    # Note(hc): max_context_chunk above is already aligned to
                    # block_size; DCP only additionally needs it to be divisible
                    # by dcp_world_size, because DCP uses cp_gather_cache, which
                    # does not require `cp_chunk_starts` to be aligned to
                    # page_size.
                    assert max_context_chunk % self.dcp_world_size == 0
                    padded_local_max_context_chunk_across_ranks = (
                        cdiv(
                            max_context_chunk,
                            self.dcp_virtual_block_size,
                        )
                        * self.dcp_local_block_size
                    )
                    local_chunk_starts = (
                        torch.arange(num_chunks, dtype=torch.int32)
                        .unsqueeze(1)
                        .expand(-1, num_prefills)
                        * padded_local_max_context_chunk_across_ranks
                    )
                    local_chunk_ends = torch.min(
                        padded_local_context_lens_cpu.unsqueeze(0),
                        local_chunk_starts
                        + padded_local_max_context_chunk_across_ranks,
                    )
                    padded_local_chunk_seq_lens = (
                        local_chunk_ends - local_chunk_starts
                    ).clamp(min=0)

                    padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
                        num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
                    )
                    torch.cumsum(
                        padded_local_chunk_seq_lens,
                        dim=1,
                        out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
                        dtype=torch.int32,
                    )

                chunked_context_metadata_cls = (
                    CudnnPrefillMetadata.ChunkedContextMetadata
                    if self._use_cudnn_prefill
                    else MLACommonPrefillMetadata.ChunkedContextMetadata
                )
                if self.dcp_world_size > 1:
                    chunked_context_metadata = chunked_context_metadata_cls(
                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                        starts=local_chunk_starts.to(device, non_blocking=True),
                        seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        seq_lens=chunk_seq_lens,
                        token_to_seq=token_to_seq_tensor_cpu.to(
                            device, non_blocking=True
                        ),
                        chunk_total_token=chunk_total_token.tolist(),
                        workspace=self.chunked_prefill_workspace,
                        padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
                        local_context_lens_allranks=local_context_lens_allranks.tolist(),
                        padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
                            device, non_blocking=True
                        ),
                        cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                        chunk_size=padded_local_max_context_chunk_across_ranks,
                    )
                else:
                    chunked_context_metadata = chunked_context_metadata_cls(
                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                        starts=chunk_starts.to(device, non_blocking=True),
                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        seq_lens=chunk_seq_lens,
                        token_to_seq=token_to_seq_tensor_cpu.to(
                            device, non_blocking=True
                        ),
                        chunk_total_token=chunk_total_token,
                        workspace=self.chunked_prefill_workspace,
                    )

                if self._use_cudnn_prefill:
                    chunked_context_metadata.seq_lens = chunk_seq_lens

                assert (
                    max(chunked_context_metadata.max_seq_lens)
                    <= self.chunked_prefill_workspace_size
                )

            prefill_metadata = self.prefill_metadata_cls(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                max_query_len=max_query_len,
                chunked_context=chunked_context_metadata,
            )

            if self._use_cudnn_prefill:
                assert isinstance(prefill_metadata, CudnnPrefillMetadata)
                prefill_metadata.query_seq_lens = (
                    prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
                )
                prefill_metadata.cudnn_workspace = self.cudnn_workspace

            if self._use_trtllm_ragged_prefill:
                prefill_metadata.query_seq_lens = (
                    prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
                )

        decode_metadata = None
        if num_decodes > 0:
            dcp_tot_seq_lens_device = None
            if self.dcp_world_size > 1:
                dcp_tot_seq_lens_device = seq_lens[:num_decodes]
                seq_lens_cpu = dcp_local_seq_lens_cpu
                seq_lens = dcp_local_seq_lens

            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor[:num_decodes, ...],
                seq_lens_cpu=seq_lens_cpu[:num_decodes],
                seq_lens_device=seq_lens[:num_decodes],
                query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
                query_start_loc_device=query_start_loc[: num_decodes + 1],
                num_decode_tokens=num_decode_tokens,
                dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
            )

        attn_metadata = self.metadata_cls(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=max_seq_len,
            num_actual_tokens=num_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        if self._use_fi_prefill and num_prefills > 0:
            assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
            self._build_fi_prefill_wrappers(attn_metadata.prefill)

        return attn_metadata
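

# Illustrative sketch only (hypothetical helper, not called by vLLM): how
# `build` sizes each context chunk. The chunked-prefill workspace is split
# evenly across the prefills that have cached context and, when
# `aot_schedule` is set, rounded down to a multiple of `page_size`. This
# assumes `round_down(x, n)` rounds x down to a multiple of n.
def _sketch_max_context_chunk(
    workspace_size: int, num_prefills_with_context: int, page_size: int = 1
) -> int:
    # Equal share of the workspace for each prefill with cached context.
    max_context_chunk = workspace_size // num_prefills_with_context
    # Round down to a page boundary; page_size=1 leaves the value unchanged.
    max_context_chunk = (max_context_chunk // page_size) * page_size
    # Mirrors `assert max_context_chunk > 0` in `build`.
    assert max_context_chunk > 0
    return max_context_chunk


# For example, a 1000-token workspace shared by 3 prefills gives each one a
# 333-token chunk, or 320 tokens once aligned to 16-token pages.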

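
# Illustrative sketch only (hypothetical helper, not called by vLLM): a
# pure-Python mirror of the `chunk_starts`, `chunk_seq_lens` and
# `cu_seq_lens_cpu` tensors built in `build`, assuming one fixed
# `max_context_chunk` shared by every prefill. Row i describes chunk i and
# column j describes prefill j.
def _sketch_context_chunks(
    context_lens: list[int], max_context_chunk: int
) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
    assert max_context_chunk > 0
    # cdiv(max(context_lens), max_context_chunk)
    num_chunks = -(-max(context_lens, default=0) // max_context_chunk)
    chunk_starts, chunk_seq_lens, cu_seq_lens = [], [], []
    for i in range(num_chunks):
        start = i * max_context_chunk
        end = start + max_context_chunk
        # Tokens of each prefill's context that fall in [start, end); clamp
        # at 0 for prefills whose context ended in an earlier chunk.
        lens = [max(0, min(ctx, end) - start) for ctx in context_lens]
        # Prefix sum with a leading zero, like cumsum into cu_seq_lens[:, 1:].
        cu = [0]
        for n in lens:
            cu.append(cu[-1] + n)
        chunk_starts.append([start] * len(context_lens))
        chunk_seq_lens.append(lens)
        cu_seq_lens.append(cu)
    return chunk_starts, chunk_seq_lens, cu_seq_lens


# With context lengths [600, 300, 0, 512] and a 256-token chunk this yields
# three chunks whose start rows match the `chunk_starts` comment in `build`.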

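
# Illustrative sketch only (hypothetical helper, not called by vLLM): the
# `token_to_seq_tensor_cpu` loop in `build` maps every gathered context
# token of a chunk to the index of the prefill that owns it (what
# `torch.repeat_interleave(range_idx, chunk_seq_lens[i])` computes), then
# zero-pads each row to the longest chunk. Note that the padding value 0 is
# also a valid prefill index.
def _sketch_token_to_seq(chunk_seq_lens: list[list[int]]) -> list[list[int]]:
    rows = []
    for lens in chunk_seq_lens:
        # Prefill j contributes lens[j] consecutive tokens to this chunk.
        rows.append([j for j, n in enumerate(lens) for _ in range(n)])
    width = max((len(r) for r in rows), default=0)
    # Pad with zeros, matching the torch.zeros initialisation.
    return [r + [0] * (width - len(r)) for r in rows]
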
def reorg_kvcache(
    allgatered_kv_c_normed: torch.Tensor,
    allgatered_k_pe: torch.Tensor,
    padded_local_chunk_seq_lens_lst: list[int],
    local_context_lens_allranks: list[list[int]],
    sum_seq_len: int,
    max_seq_len: int,
    chunk_size: int,
    chunk_idx: int,
    toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reorganize and unpad the KV cache gathered locally on each CP rank into
    the TP layout expected by the attention kernel, e.g.
    allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
                              T0_4, T0_5, pad, pad, T1_2, pad, ...]
    -> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
                                  T1_0, T1_1, T1_2, ...]

    Args:
        padded_local_chunk_seq_lens_lst: local chunk context lengths
            under current CP rank.
        local_context_lens_allranks: local context lengths on each CP rank.
        sum_seq_len: total number of unpadded tokens in this chunk across all
            requests; checked against the length of the reorganized output.
        max_seq_len: the largest number of tokens any single request has in
            this chunk.
        chunk_size: the local padded max context chunk computed while
            building chunked_context_metadata.
        chunk_idx: index of the current chunk in chunked prefill.
        toks: the number of tokens each CP rank gathered locally, i.e. the
            per-rank stride into the all-gathered tensors.
    """
    kv_c_segments = []
    k_pe_segments = []
    src_token_idx = 0
    max_seq_len_check = 0
    for padded_local_chunk_seq_len, local_context_lens in zip(
        padded_local_chunk_seq_lens_lst, local_context_lens_allranks
    ):
        cur_seq_len = 0
        for rank, local_context_len in enumerate(local_context_lens):
            # Note(qcs): We split the context into multiple chunks,
            # depending on the size of the workspace.
            # local_context in dcp0:   |-----------------|
            # local_context in dcp1:   |--------------|
            # n*padded_local_chunk:    |-----|-----|-----|
            # local_chunk_len in dcp1: |-----|-----|--|
            # so we need to update the last chunk length in dcp1.
            local_chunk_len = min(
                max(0, local_context_len - chunk_idx * chunk_size),
                padded_local_chunk_seq_len,
            )
            if local_chunk_len != 0:
                kv_c_segment = allgatered_kv_c_normed[
                    rank * toks + src_token_idx : rank * toks
                    + src_token_idx
                    + local_chunk_len
                ]
                k_pe_segment = allgatered_k_pe[
                    rank * toks + src_token_idx : rank * toks
                    + src_token_idx
                    + local_chunk_len
                ]
                kv_c_segments.append(kv_c_segment)
                k_pe_segments.append(k_pe_segment)
                cur_seq_len += local_chunk_len
        max_seq_len_check = max(max_seq_len_check, cur_seq_len)
        src_token_idx += padded_local_chunk_seq_len
    reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
    reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
    assert reorganized_kv_c_normed.shape[0] == sum_seq_len
    assert reorganized_k_pe.shape[0] == sum_seq_len
    assert max_seq_len_check == max_seq_len
    return reorganized_kv_c_normed, reorganized_k_pe

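
# Illustrative sketch only (hypothetical helper, not called by vLLM): the
# index arithmetic of `reorg_kvcache` applied to plain Python lists, so the
# docstring example (two CP ranks, two requests) can be checked by hand.
# Each rank's gathered buffer holds `toks` entries, laid out request by
# request, with each request padded to its padded local chunk length.
def _sketch_reorg(
    allgathered: list[str],
    padded_local_chunk_seq_lens: list[int],
    local_context_lens_allranks: list[list[int]],
    chunk_size: int,
    chunk_idx: int,
    toks: int,
) -> list[str]:
    out: list[str] = []
    src_token_idx = 0
    for padded_len, local_context_lens in zip(
        padded_local_chunk_seq_lens, local_context_lens_allranks
    ):
        for rank, local_context_len in enumerate(local_context_lens):
            # Valid (unpadded) tokens this rank holds for this request in the
            # current chunk.
            n = min(max(0, local_context_len - chunk_idx * chunk_size), padded_len)
            begin = rank * toks + src_token_idx
            out.extend(allgathered[begin : begin + n])
        # Skip this request's padded slot in every rank's buffer.
        src_token_idx += padded_len
    return out
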

# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
# and MLACommonImpl -> MLACommonDenseImpl or something like that
class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer=None,
        q_pad_num_heads: int | None = None,
    ) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported for MLA")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.kv_b_proj = kv_b_proj
        self.indexer = indexer
        self.q_pad_num_heads = q_pad_num_heads
        self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
            )

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(
                    layer.input_size_per_partition,
                    dtype=act_dtype,
                    device=get_layer_weight(layer).device,
                )
                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # We currently do not have the quantized bmm's needed for `W_UV` and
        # `W_UK_T`, so we store fp16/bf16 copies and perform the bmm's in
        # 16-bit; the extra memory overhead of this is fairly low.
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        ), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}"
        )
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        if self.is_aiter_triton_fp8_bmm_enabled:
            W_K = W_UK.transpose(0, 1)  # 16 512 128
            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                W_K, dtype=current_platform.fp8_dtype()
            )
            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
                W_V, dtype=current_platform.fp8_dtype()
            )

            # The kernel operates on non-padded inputs, so we pre-compile the
            # Triton kernel to avoid runtime compilation for unseen batch sizes.
            # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
            # On DS-R1, this step adds roughly 50s to the model loading time.
            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
            pre_compilation_list = list(range(1, max_batch_size + 1))
            if is_global_first_rank():
                pre_compilation_list = tqdm(
                    pre_compilation_list,
                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                    total=max_batch_size,
                )

            for m in pre_compilation_list:
                x = torch.empty(
                    (self.W_K.shape[0], m, self.W_K.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_K.device,
                )
                rocm_aiter_ops.triton_fp8_bmm(
                    x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
                )

                x = torch.empty(
                    (self.W_V.shape[0], m, self.W_V.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_V.device,
                )
                rocm_aiter_ops.triton_fp8_bmm(
                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
                )
        else:
            # Convert from (L, N, V) to (N, L, V)
            self.W_UV = W_UV.transpose(0, 1)
            # Convert from (L, N, P) to (N, P, L)
            self.W_UK_T = W_UK.permute(1, 2, 0)

    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

        if self.is_aiter_triton_fp8_bmm_enabled:
            out = out.view(-1, self.num_heads, self.v_head_dim)
            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
            x = rocm_aiter_ops.triton_fp8_bmm(
                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
            )
        else:
            # Convert from (B, N * V) to (N, B, V)
            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)

            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"

            # Convert from (N, B, V) to (B, N * V)
            out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

            # Adjust output buffer shape back to the original (B, N * V)
            N, B, V = out.shape
            out.resize_((B, N * V))
            out.copy_(out_new)  # Copy result

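
# Illustrative sketch only (hypothetical helper, not called by vLLM): the
# shape bookkeeping in `process_weights_after_loading`, on nested lists.
# The transposed kv_b_proj weight, viewed as (L, N, P + V), is split on the
# last dim into W_UK (L, N, P) and W_UV (L, N, V). The non-AITER path then
# turns W_UK into W_UK_T (N, P, L) with `permute(1, 2, 0)` and W_UV into
# (N, L, V) with `transpose(0, 1)`.
def _sketch_split_kv_b_proj(w, qk_nope_head_dim: int):
    # w: nested lists of shape (L, N, P + V), with P = qk_nope_head_dim.
    num_l, num_n = len(w), len(w[0])
    # Split the last dim: the first P entries go to W_UK, the rest to W_UV.
    w_uk = [[row[:qk_nope_head_dim] for row in heads] for heads in w]
    w_uv = [[row[qk_nope_head_dim:] for row in heads] for heads in w]
    # (L, N, P) -> (N, P, L)
    w_uk_t = [
        [
            [w_uk[li][ni][pi] for li in range(num_l)]
            for pi in range(qk_nope_head_dim)
        ]
        for ni in range(num_n)
    ]
    # (L, N, V) -> (N, L, V)
    w_uv_t = [[w_uv[li][ni] for li in range(num_l)] for ni in range(num_n)]
    return w_uk_t, w_uv_t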

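
# Illustrative sketch only (hypothetical helper, not called by vLLM): what
# the non-AITER `_v_up_proj` path computes, on nested lists. For each token
# and head, the L-dim latent is multiplied by that head's (L, V) slice of
# W_UV (the per-head bmm). The per-head outputs are then concatenated in head
# order, which gives the (B, N * V) layout produced by
# `out.transpose(0, 1).reshape(-1, num_heads * v_head_dim)`.
def _sketch_v_up_proj(x, w_uv):
    # x: (B, N, L); w_uv: (N, L, V). Returns (B, N * V).
    out = []
    for token in x:
        row = []
        for head_latent, head_w in zip(token, w_uv):
            v_dim = len(head_w[0])
            # (L,) @ (L, V) -> (V,)
            row.extend(
                sum(head_latent[i] * head_w[i][v] for i in range(len(head_latent)))
                for v in range(v_dim)
            )
        out.append(row)
    return out
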
class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        if use_flashinfer_prefill():
            logger.debug_once("Using FlashInfer prefill for MLA")
            self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
            self._pad_v = False
        elif use_trtllm_ragged_deepseek_prefill():
            logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
            self._run_prefill_context_chunk = (
                self._run_prefill_context_chunk_trtllm_ragged
            )
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
            self._pad_v = False
        elif use_cudnn_prefill():
            logger.debug_once("Using CUDNN prefill for MLA")
            self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
            self._pad_v = False
        else:  # Use FlashAttention
            logger.debug_once("Using FlashAttention prefill for MLA")
            self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa

            # Handle the differences between the flash_attn_varlen_func from
            # flash_attn and the one from vllm_flash_attn. The former is used on
            # ROCm; the latter takes an additional parameter (fa_version) that
            # selects FA2 vs FA3.
            self.flash_attn_varlen_func = flash_attn_varlen_func
            self.vllm_flash_attn_version = get_flash_attn_version()
            if self.vllm_flash_attn_version is not None:
                self.flash_attn_varlen_func = functools.partial(
                    flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
                )

            # For MLA, the V head dim is smaller than the QK head dim, so we
            # zero-pad V to the QK head dim for attention backends that do not
            # support different head dims. Padding is not needed on Hopper
            # (compute capability 9.x) with FA3.
            self._pad_v = self.vllm_flash_attn_version is None or not (
                self.vllm_flash_attn_version == 3
                and current_platform.get_device_capability()[0] == 9
            )
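            # Worked note (illustrative reasoning): with attention weights
            # A = softmax(scale * Q K^T), A [V | 0] = [A V | 0], so the first
            # v_head_dim columns of the padded output equal the unpadded
            # output, and the LSE is unchanged because it depends only on Q
            # and K. `_forward_prefill` slices the padding off again with
            # `[..., : v.shape[-1]]`.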

        self.dcp_world_size: int | None = None

        self.chunked_prefill_workspace_size = (
            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
                get_current_vllm_config()
            )
        )
        self.cp_kv_cache_interleave_size: int = (
            get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
        )

    def _flash_attn_varlen_diff_headdims(
        self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
    ):
        maybe_padded_v = v
        if self._pad_v:
            maybe_padded_v = torch.nn.functional.pad(
                v, [0, q.shape[-1] - v.shape[-1]], value=0
            )

        if is_vllm_fa:
            kwargs["return_softmax_lse"] = return_softmax_lse
        else:
            # ROCm leverages the upstream flash_attn, which takes a parameter
            # called "return_attn_probs" instead of return_softmax_lse
            kwargs["return_attn_probs"] = return_softmax_lse
        if vllm_is_batch_invariant():
            kwargs["num_splits"] = 1

        attn_out = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=maybe_padded_v,
            softmax_scale=softmax_scale,
            **kwargs,
        )

        # Unpack the output if there are multiple results
        lse = None
        if isinstance(attn_out, tuple):
            attn_out, lse = attn_out[0], attn_out[1]

        # Remain consistent with old `flash_attn_varlen_func` where there
        # is only one output tensor if `return_softmax_lse` is False.
        if return_softmax_lse:
            return attn_out, lse
        return attn_out

    def _run_prefill_new_tokens_fa(
        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
    ):
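        # Illustrative example (hypothetical lengths, assuming the prefill
        # offsets start at 0): for two prefill requests with 3 and 4 new
        # tokens, query_start_loc = [0, 3, 7] and max_query_len = 4. New tokens
        # attend causally only to each other here (the cached context is
        # handled by the chunked-context path), so the same offsets also serve
        # as cu_seqlens_k.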
        return self._flash_attn_varlen_diff_headdims(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=prefill.query_start_loc,
            cu_seqlens_k=prefill.query_start_loc,
            max_seqlen_q=prefill.max_query_len,
            max_seqlen_k=prefill.max_query_len,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=return_softmax_lse,
        )

    def _run_prefill_new_tokens_fi(
        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
    ):
        assert isinstance(prefill, FlashInferPrefillMetadata)
        assert prefill.prefill_main is not None

        ret = prefill.prefill_main.run(
            q=q,
            k=k,
            v=v,
            return_lse=return_softmax_lse,
        )

        if isinstance(ret, tuple):
            # Convert LSE from (q_len, num_heads) to (num_heads, q_len)
            return ret[0], ret[1].transpose(0, 1).contiguous()
        return ret

    def _run_prefill_new_tokens_cudnn(
        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
    ):
        assert isinstance(prefill, CudnnPrefillMetadata)
        assert prefill.query_seq_lens is not None
        output, lse = cudnn_batch_prefill_with_kv_cache(
            q=q,
            k_cache=k,
            v_cache=v,
            scale=self.scale,
            workspace_buffer=prefill.cudnn_workspace,
            max_token_per_sequence=prefill.max_query_len,
            max_sequence_kv=prefill.max_query_len,
            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
            actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
            causal=True,
            # return_lse=False is not supported yet
            return_lse=True,
            # Indicates actual_seq_lens are on GPU or CPU.
            is_cuda_graph_compatible=True,
        )
        if return_softmax_lse:
            return output, lse
        return output

    def _run_prefill_context_chunk_fa(
        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
    ):
        assert prefill.chunked_context is not None
        return self._flash_attn_varlen_diff_headdims(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=prefill.query_start_loc,
            cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx],
            max_seqlen_q=prefill.max_query_len,
            max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx],
            softmax_scale=self.scale,
            causal=False,  # Context is unmasked
            return_softmax_lse=True,
        )

    def _run_prefill_context_chunk_fi(
        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
    ):
        assert isinstance(prefill, FlashInferPrefillMetadata)

        attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
            q=q,
            k=k,
            v=v,
            return_lse=True,
        )
        # Convert from (q_len, num_heads) to (num_heads, q_len)
        return attn_out, lse.transpose(0, 1).contiguous()

    def _run_prefill_context_chunk_cudnn(
        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
    ):
        assert isinstance(prefill, CudnnPrefillMetadata)
        assert prefill.chunked_context is not None
        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
        assert prefill.query_seq_lens is not None
        return cudnn_batch_prefill_with_kv_cache(
            q=q,
            k_cache=k,
            v_cache=v,
            scale=self.scale,
            workspace_buffer=prefill.cudnn_workspace,
            max_token_per_sequence=prefill.max_query_len,
            max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
            actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view(
                -1, 1, 1, 1
            ),
            causal=False,
            return_lse=True,
            # Indicates actual_seq_lens are on GPU or CPU.
            is_cuda_graph_compatible=True,
        )

    def _run_prefill_new_tokens_trtllm_ragged(
        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
    ):
        """TRT-LLM ragged attention for new tokens (causal)."""
        from flashinfer.prefill import trtllm_ragged_attention_deepseek

        assert prefill.query_seq_lens is not None

        ret = trtllm_ragged_attention_deepseek(
            query=q,
            key=k,
            value=v,
            workspace_buffer=self._workspace_buffer,
            seq_lens=prefill.query_seq_lens,
            max_q_len=prefill.max_query_len,
            max_kv_len=prefill.max_query_len,
            bmm1_scale=self.scale,
            bmm2_scale=1.0,
            o_sf_scale=1.0,
            batch_size=prefill.query_seq_lens.shape[0],
            window_left=-1,
            cum_seq_lens_q=prefill.query_start_loc,
            cum_seq_lens_kv=prefill.query_start_loc,
            enable_pdl=False,
            is_causal=True,
            return_lse=return_softmax_lse,
        )

        if isinstance(ret, tuple):
            # Convert from (q_len, num_heads) to (num_heads, q_len)
            return ret[0], ret[1].transpose(0, 1).contiguous()
        return ret

    def _run_prefill_context_chunk_trtllm_ragged(
        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
    ):
        """TRT-LLM ragged attention for context chunks (non-causal)."""
        from flashinfer.prefill import trtllm_ragged_attention_deepseek

        assert prefill.chunked_context is not None
        assert prefill.chunked_context.seq_lens[chunk_idx] is not None

        out = torch.zeros(
            q.shape[0],
            q.shape[1],
            v.shape[2],
            device=q.device,
            dtype=q.dtype,
        )
        self._workspace_buffer.fill_(0)

        attn_out, lse = trtllm_ragged_attention_deepseek(
            query=q,
            key=k,
            value=v,
            workspace_buffer=self._workspace_buffer,
            seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
            max_q_len=prefill.max_query_len,
            max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
            bmm1_scale=self.scale,
            bmm2_scale=1.0,
            o_sf_scale=1.0,
            batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
            window_left=-1,
            cum_seq_lens_q=prefill.query_start_loc,
            cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
            enable_pdl=False,
            is_causal=False,
            return_lse=True,
            out=out,
        )

        # Convert from (q_len, num_heads) to (num_heads, q_len)
        return attn_out, lse.transpose(0, 1).contiguous()

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
            )

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
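                # Worked note (assuming the layer computes y = x W^T with W
                # stored as (output, input)): applying it to the identity
                # x = I of shape (input, input) yields y = W^T, which is why
                # the result is transposed back to (output, input) below.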
                eye = torch.eye(
                    layer.input_size_per_partition,
                    dtype=act_dtype,
                    device=get_layer_weight(layer).device,
                )
                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # We currently lack the quantized bmm kernels needed for `W_UV` and
        # `W_UK_T`, so we store fp16/bf16 copies and perform the bmm's in
        # 16-bit; the extra memory overhead of this is fairly low.
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        ), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}"
        )
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        if self.is_aiter_triton_fp8_bmm_enabled:
            W_K = W_UK.transpose(0, 1)  # (L, N, P) -> (N, L, P), e.g. 16 512 128
            W_V = W_UV.permute(1, 2, 0)  # (L, N, V) -> (N, V, L), e.g. 16 128 512
            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                W_K, dtype=current_platform.fp8_dtype()
            )
            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
                W_V, dtype=current_platform.fp8_dtype()
            )

            # The kernel operates on non-padded inputs, so we pre-compile the
            # Triton kernel to avoid runtime compilation for unseen batch sizes.
            # Pre-compile for batch sizes 1 to 1024 to cover most use cases.
            # On DS-R1, this step adds roughly 50s to the model loading time.
            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
            pre_compilation_list = list(range(1, max_batch_size + 1))
            if is_global_first_rank():
                pre_compilation_list = tqdm(
                    pre_compilation_list,
                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                    total=max_batch_size,
                )

            for m in pre_compilation_list:
                x = torch.empty(
                    (self.W_K.shape[0], m, self.W_K.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_K.device,
                )
                rocm_aiter_ops.triton_fp8_bmm(
                    x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
                )

                x = torch.empty(
                    (self.W_V.shape[0], m, self.W_V.shape[2]),
                    dtype=torch.bfloat16,
                    device=self.W_V.device,
                )
                rocm_aiter_ops.triton_fp8_bmm(
                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
                )
        else:
            # Convert from (L, N, V) to (N, L, V)
            self.W_UV = W_UV.transpose(0, 1)
            # Convert from (L, N, P) to (N, P, L)
            self.W_UK_T = W_UK.permute(1, 2, 0)
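            # Shape walk-through (illustrative, using the DSV3 sizes from the
            # module docstring, Lkv = 512, P = 128, V = 128, and
            # N = self.num_heads):
            #   kv_b_proj_weight after .view  -> (512, N, 256)
            #   W_UK, W_UV after .split       -> (512, N, 128) each
            #   self.W_UV = W_UV.transpose    -> (N, 512, 128), i.e. (N, L, V)
            #   self.W_UK_T = W_UK.permute    -> (N, 128, 512), i.e. (N, P, L)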

    def _compute_prefill_context(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_scale: torch.Tensor,
    ):
        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        assert prefill_metadata.chunked_context is not None

        output = None
        iters = len(prefill_metadata.chunked_context.seq_tot)
        workspace = prefill_metadata.chunked_context.workspace
        for i in range(iters):
            toks = prefill_metadata.chunked_context.seq_tot[i]
            ops.gather_and_maybe_dequant_cache(
                src_cache=kv_c_and_k_pe_cache,
                dst=workspace,
                block_table=prefill_metadata.block_table,
                cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
                num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
                kv_cache_dtype=self.kv_cache_dtype,
                scale=k_scale,
                seq_starts=prefill_metadata.chunked_context.starts[i],
            )

            kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
            k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)

            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

            attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                prefill=prefill_metadata,
                chunk_idx=i,
                q=q,
                k=k,
                v=v,
            )

            if output is None:
                output = attn_output
                output_lse = attn_softmax_lse
            else:
                output_tmp = torch.empty_like(output)
                output_lse_tmp = torch.empty_like(output_lse)
                merge_attn_states(
                    output=output_tmp,
                    output_lse=output_lse_tmp,
                    prefix_output=output,
                    prefix_lse=output_lse,
                    suffix_output=attn_output,
                    suffix_lse=attn_softmax_lse,
                )
                output = output_tmp
                output_lse = output_lse_tmp

        return output, output_lse

    def _context_parallel_compute_prefill_context(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_scale: torch.Tensor,
        dcp_world_size: int,
    ):
        assert k_scale is None, "DCP does not support scaled KV cache yet."
        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        assert prefill_metadata.chunked_context is not None
        assert prefill_metadata.chunked_context.padded_local_chunk_seq_lens is not None
        assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
        assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
        assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
        assert prefill_metadata.chunked_context.chunk_size is not None

        output = None
        iters = len(prefill_metadata.chunked_context.seq_tot)
        workspace = prefill_metadata.chunked_context.workspace

        for i in range(iters):
            toks = prefill_metadata.chunked_context.seq_tot[i]
            ops.cp_gather_cache(
                src_cache=kv_c_and_k_pe_cache,
                dst=workspace,
                block_table=prefill_metadata.block_table,
                cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
                    i
                ],
                batch_size=attn_metadata.num_prefills,
                seq_starts=prefill_metadata.chunked_context.starts[i],
            )
            # workspace
            # |------- N tokens ---------|--------- N*dcp_size tokens ----------|
            # |<- use for local gather ->|<--------- use for allgather -------->|
            allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
            assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0]
            assert toks <= allgather_offset
            local_gathered_kvcache = workspace[:toks]
            cur_allgather_workspace = workspace[
                allgather_offset : allgather_offset * (1 + dcp_world_size)
            ]
            assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
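            # Worked example (illustrative numbers): with dcp_world_size = 2
            # and a workspace of 3 * N rows, allgather_offset = N; rows [0, N)
            # hold this rank's local gather and rows [N, 3N) receive the
            # all-gathered tokens from both ranks, which needs toks * 2 <= 2N.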
            cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size]
            cur_allgather_kvcache.copy_(
                get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
            )
            assert (
                cur_allgather_kvcache.shape[-1]
                == self.kv_lora_rank + self.qk_rope_head_dim
            )
            allgatered_kv_c_normed, allgatered_k_pe = cur_allgather_kvcache.unsqueeze(
                1
            ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

            kv_c_normed, k_pe = reorg_kvcache(
                allgatered_kv_c_normed,
                allgatered_k_pe,
                padded_local_chunk_seq_lens_lst=prefill_metadata.chunked_context.padded_local_chunk_seq_lens[
                    i
                ],
                local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
                sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
                max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
                chunk_size=prefill_metadata.chunked_context.chunk_size,
                chunk_idx=i,
                toks=toks,
            )

            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

            attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                prefill=prefill_metadata,
                chunk_idx=i,
                q=q,
                k=k,
                v=v,
            )

            if output is None:
                output = attn_output
                output_lse = attn_softmax_lse
            else:
                output_tmp = torch.empty_like(output)
                output_lse_tmp = torch.empty_like(output_lse)
                merge_attn_states(
                    output=output_tmp,
                    output_lse=output_lse_tmp,
                    prefix_output=output,
                    prefix_lse=output_lse,
                    suffix_output=attn_output,
                    suffix_lse=attn_softmax_lse,
                )
                output = output_tmp
                output_lse = output_lse_tmp

        return output, output_lse

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_scale: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        # TODO (zyongye): Prefill function here
        assert attn_metadata.prefill is not None
        assert self.dcp_world_size is not None

        has_context = attn_metadata.prefill.chunked_context is not None
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        output_prefill = self._run_prefill_new_tokens(
            prefill=attn_metadata.prefill,
            q=q,
            k=k,
            v=v,
            return_softmax_lse=has_context,
        )

        if has_context:
            suffix_output, suffix_lse = output_prefill
            if self.dcp_world_size > 1:
                context_output, context_lse = (
                    self._context_parallel_compute_prefill_context(
                        q,
                        kv_c_and_k_pe_cache,
                        attn_metadata,
                        k_scale=None,
                        dcp_world_size=self.dcp_world_size,
                    )
                )
            else:
                context_output, context_lse = self._compute_prefill_context(
                    q, kv_c_and_k_pe_cache, attn_metadata, k_scale
                )

            # unpad if necessary
            if self._pad_v:
                context_output = context_output[..., : v.shape[-1]]
                suffix_output = suffix_output[..., : v.shape[-1]]

            output = output.view(-1, self.num_heads, self.v_head_dim)
            merge_attn_states(
                output=output,
                prefix_output=context_output,
                prefix_lse=context_lse,
                suffix_output=suffix_output,
                suffix_lse=suffix_lse,
            )
        else:
            output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
            output.copy_(output_prefill)

    @abstractmethod
    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: M,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        raise NotImplementedError

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: M,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for MLACommonImpl"
            )

        if attn_metadata is None:
            # During the profile run, try to simulate the worst-case output size
            # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`,
            # since this can be large
            _ = torch.empty(
                (
                    self.chunked_prefill_workspace_size,
                    self.num_heads,
                    self.qk_nope_head_dim + self.v_head_dim,
                ),
                device=k_c_normed.device,
                dtype=k_c_normed.dtype,
            )

            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        if self.dcp_world_size is None:
            self.dcp_world_size = get_dcp_group().world_size

        fp8_attention = self.kv_cache_dtype.startswith("fp8")

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
        output = output[:num_actual_toks, ...]
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        assert (
            attn_metadata.num_decodes is not None
            and attn_metadata.num_prefills is not None
            and attn_metadata.num_decode_tokens is not None
        )

        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        num_decode_tokens = attn_metadata.num_decode_tokens

        decode_q = q[:num_decode_tokens]

        prefill_q = q[num_decode_tokens:]
        prefill_k_pe = k_pe[num_decode_tokens:]
        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
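        # Illustrative example (hypothetical batch): with two decode requests
        # of one token each followed by one 5-token prefill,
        # num_decode_tokens = 2, so q[:2] goes to the decode path and q[2:7]
        # to the prefill path; k_pe and k_c_normed are split at the same offset.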

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if fp8_attention:
            kv_cache = kv_cache.view(current_platform.fp8_dtype())

        if has_prefill:
            self._forward_prefill(
                prefill_q,
                prefill_k_c_normed,
                prefill_k_pe,
                kv_cache,
                attn_metadata,
                layer._k_scale,
                output=output[num_decode_tokens:],
            )

        if has_decode:
            assert attn_metadata.decode is not None

            decode_q_nope, decode_q_pe = decode_q.split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
            )

            # Convert from (B, N, P) to (N, B, P)
            decode_q_nope = decode_q_nope.transpose(0, 1)

            if self.q_pad_num_heads is not None:
                B, N, L = decode_q_pe.shape
                decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
                decode_pe_padded.resize_((B, N, L))
                decode_pe_padded.copy_(decode_q_pe)
                decode_q_pe = decode_pe_padded

            if self.is_aiter_triton_fp8_bmm_enabled:
                # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
                    decode_q_nope,
                    self.W_K,
                    self.W_K_scale,
                    group_size=128,
                    transpose_bm=True,
                )
            else:
                # Pad the number of heads if necessary (for the underlying kernel)
                N, B, P = decode_q_nope.shape
                _, _, L = self.W_UK_T.shape
2156

                if self.q_pad_num_heads is not None:
                    decode_ql_nope = decode_q_nope.new_empty(
                        (self.q_pad_num_heads, B, L)
                    )
                    decode_ql_nope.resize_((N, B, L))
                else:
                    decode_ql_nope = decode_q_nope.new_empty((N, B, L))

                # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
                torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)

                # Convert from (N, B, L) to (B, N, L)
                decode_ql_nope = decode_ql_nope.transpose(0, 1)
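                # Worked shapes (illustrative, assuming the DSV3 sizes P=128 and
                # Lkv=512): decode_q_nope (N, B, 128) @ W_UK_T (N, 128, 512)
                # -> (N, B, 512), then a zero-copy transpose to (B, N, 512). This
                # folds the K up-projection into the query, so decode attention
                # can score directly against the 512-wide latent cache entries.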

            if fp8_attention:
                ql_nope_shape = decode_ql_nope.shape
                q_pe_shape = decode_q_pe.shape
                assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
                assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
                decode_q_shape = (
                    ql_nope_shape[0],
                    ql_nope_shape[1],
                    ql_nope_shape[2] + q_pe_shape[2],
                )
                # Using empty and copy since torch.cat introduces significant overhead.
                decode_q0 = torch.empty(
                    decode_q_shape,
                    device=decode_ql_nope.device,
                    dtype=decode_ql_nope.dtype,
                )
                decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
                decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
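                # Worked shapes (illustrative, assuming the DSV3 sizes Lkv=512 and
                # R=64): decode_ql_nope (B, N, 512) and decode_q_pe (B, N, 64) are
                # written side by side into decode_q0 (B, N, 576), the same result
                # as torch.cat([decode_ql_nope, decode_q_pe], dim=-1).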

                decode_q, _ = ops.scaled_fp8_quant(
                    decode_q0.view(decode_q_shape[0], -1),
                    layer._q_scale,
                )
                decode_q = decode_q.view(decode_q_shape)
            else:
                decode_q = (decode_ql_nope, decode_q_pe)
            if self.dcp_world_size > 1:
                assert not fp8_attention, "DCP does not support fp8 KV cache yet."
                # Concatenate decode_ql_nope and decode_q_pe -> (B, N, L + R)
                decode_q = torch.cat(decode_q, dim=-1)
                # All-gather decode_q along the head dimension.
                decode_q = get_dcp_group().all_gather(decode_q, dim=1)
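                # Shape note: all_gather along dim=1 concatenates the per-rank
                # tensors on the head axis, so decode_q goes from (B, N, L + R)
                # to (B, N * dcp_world_size, L + R), assuming every rank holds
                # the same N.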

            # Run decode attention.
            attn_out, lse = self._forward_decode(
                decode_q, kv_cache, attn_metadata, layer
            )

            # Correct the DCP attention output using the LSE.
            if self.dcp_world_size > 1:
                attn_out = cp_lse_ag_out_rs(
                    attn_out,
                    lse,
                    get_dcp_group(),
                    is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
                )
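                # Sketch of the standard LSE merge this step is assumed to
                # perform (the helper's internals are not shown here). For the
                # partial outputs o_i and log-sum-exps lse_i from each DCP rank:
                #   lse = log(sum_i exp(lse_i))
                #   o   = sum_i exp(lse_i - lse) * o_i
                # With a base-2 LSE (is_lse_base_on_e=False), exp and log become
                # 2**x and log2.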

            # v_up projection
            self._v_up_proj(attn_out, out=output[:num_decode_tokens])
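            # Worked shapes (illustrative, assuming the DSV3 sizes Lkv=512 and
            # V=128): _v_up_proj is assumed to map the latent attention output
            # (B, N, 512) to per-head values (B, N, 128), flattened to
            # (B, N * 128) in the first num_decode_tokens rows of output.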
        return output_padded