flashinfer.py 67.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer with FlashInfer."""
4

5
from dataclasses import dataclass
6
from functools import partial
7
from typing import ClassVar
8

9
import numpy as np
10
import torch
11
12
13
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
14
    BatchPrefillWithRaggedKVCacheWrapper,
15
16
    MultiLevelCascadeAttentionWrapper,
)
17
from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache
18
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
19
from flashinfer.utils import FP4Tensor
20
from typing_extensions import override
21

22
from vllm import envs
23
24
25
26
27
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_current_vllm_config_or_none,
)
28
from vllm.config.cache import CacheDType
29
from vllm.distributed.parallel_state import get_dcp_group
30
from vllm.logger import init_logger
31
from vllm.model_executor.layers.batch_invariant import (
32
    vllm_is_batch_invariant,
33
)
34
from vllm.model_executor.layers.quantization.utils.quant_utils import (
35
36
    QuantKey,
    kFp8StaticTensorSym,
37
    kNvfp4Dynamic,
38
)
39
from vllm.platforms import current_platform
40
from vllm.platforms.interface import DeviceCapability
41
from vllm.triton_utils import tl, triton
42
43
44
45
from vllm.utils.flashinfer import (
    can_use_trtllm_attention,
    use_trtllm_attention,
)
46
from vllm.utils.math_utils import cdiv
47
from vllm.utils.platform_utils import is_pin_memory_available
48
from vllm.utils.torch_utils import is_strictly_contiguous
49
50
from vllm.v1.attention.backend import (
    AttentionBackend,
51
    AttentionCGSupport,
52
    AttentionImpl,
53
    AttentionMetadataBuilder,
54
    AttentionType,
55
    CommonAttentionMetadata,
56
57
    MultipleOf,
)
58
from vllm.v1.attention.backends.utils import (
59
    KVCacheLayoutType,
60
    get_dcp_local_seq_lens,
61
62
63
64
65
    get_kv_cache_layout,
    get_per_layer_parameters,
    infer_global_hyperparameters,
    split_decodes_and_prefills,
)
66
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
67
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
68
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
69
from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs
70
from vllm.v1.utils import CpuGpuBuffer
71

72
# Workspace buffer size in bytes (2048 * 1024 * 1024 = 2 GiB) used in place of
# envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE when batch-invariant mode is on.
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024

# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# Storage dtype used for FP4 data in this module.
FP4_DTYPE = torch.uint8

logger = init_logger(__name__)

79
80
81
82
83
84
85
# Module-level cache for the buffer returned by
# _get_trtllm_gen_workspace_buffer(); allocated lazily on first use.
trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
    """Return the cached TRTLLM-gen workspace buffer, allocating it on first use.

    The buffer is a zero-filled ``torch.uint8`` tensor on the ``"cuda"`` device
    with ``envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE`` elements. It is stored
    in the module-level ``trtllm_gen_workspace_buffer`` so that every later
    call returns the same tensor.
    """
    global trtllm_gen_workspace_buffer
    if trtllm_gen_workspace_buffer is None:
        trtllm_gen_workspace_buffer = torch.zeros(
            envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
        )
    return trtllm_gen_workspace_buffer

90

91
92
93
94
95
96
97
98
99
100
101
102
103
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
    kv_cache_ptr,
    block_tables_prefill_ptr,
    block_table_stride,
    mock_kv_cache_ptr,
    k_scale_ptr,
    v_scale_ptr,
    K_CACHE_STRIDE: tl.constexpr,
    KV_CACHE_STRIDE: tl.constexpr,
):
    """Dequantize one KV-cache page into a densely packed "mock" cache.

    Launched on a 2D grid: program axis 0 selects the request (row of the
    block table) and axis 1 selects the entry within that row. Each program
    looks up the source page id, scales its K and V halves by the scalar
    loaded from ``k_scale_ptr`` / ``v_scale_ptr`` and stores the result, cast
    to the element type of ``mock_kv_cache_ptr``, into mock page
    ``batch_idx * block_table_stride + mock_block_table_idx + 1``.

    ``K_CACHE_STRIDE`` is the number of elements in the K half of a page and
    ``KV_CACHE_STRIDE`` the number of elements in a whole page (K followed
    by V).
    """
    batch_idx = tl.program_id(0).to(tl.int64)
    mock_block_table_idx = tl.program_id(1).to(tl.int64)
    orig_page_num = tl.load(
        block_tables_prefill_ptr + batch_idx * block_table_stride + mock_block_table_idx
    ).to(tl.int64)
    # Block-table entries <= 0 are skipped; nothing is written for them.
    if orig_page_num <= 0:
        return
    dequant_dtype = mock_kv_cache_ptr.dtype.element_ty

    # Dequantize K
    k_scale_val = tl.load(k_scale_ptr)
    offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
    fp8_vals = tl.load(kv_cache_ptr + offset)
    dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
    # The "+ 1" leaves mock page 0 unused, so destination pages start at 1.
    mock_cache_offset = (
        batch_idx * block_table_stride + mock_block_table_idx + 1
    ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
    dequantized_vals = dequantized_vals.to(dequant_dtype)
    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)

    # Dequantize V
    v_scale_val = tl.load(v_scale_ptr)
    # V follows K inside a page, hence the extra K_CACHE_STRIDE offset.
    offset = (
        orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
    )
    fp8_vals = tl.load(kv_cache_ptr + offset)
    dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
    mock_cache_offset = (
        (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE
        + K_CACHE_STRIDE
        + tl.arange(0, K_CACHE_STRIDE)
    )
    dequantized_vals = dequantized_vals.to(dequant_dtype)
    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)


def trtllm_prefill_attn_kvfp8_dequant(
    kv_cache: torch.Tensor,
    block_tables_prefill: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    dequant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a dequantized copy of the KV-cache pages referenced by a prefill.

    Args:
        kv_cache: 5-D cache tensor whose dimension 1 must have size 2 (K, V).
        block_tables_prefill: 2-D block table of the prefill requests.
        k_scale: Scale applied to the K values (loaded as a scalar).
        v_scale: Scale applied to the V values (loaded as a scalar).
        dequant_dtype: Output dtype; must be bfloat16 or float16.

    Returns:
        ``(mock_kv_cache, mock_block_table)``. ``mock_kv_cache`` has one page
        per block-table entry plus an unused page 0, and ``mock_block_table``
        has the shape of ``block_tables_prefill`` with entries numbered
        sequentially from 1.
    """
    batch_size, num_of_page_per_token = block_tables_prefill.shape
    s = kv_cache.shape
    assert s[1] == 2
    assert dequant_dtype in (torch.bfloat16, torch.float16)
    k_cache_stride = s[2] * s[3] * s[4]
    kv_cache_stride = k_cache_stride * s[1]
    new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
    # mock kv cache contains just the pages needed by this prefill.
    # It is allocated uninitialized; the kernel only writes pages whose source
    # block-table entry is > 0.
    mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device)
    # we simply sequentially index the pages needed by this prefill
    mock_block_table = torch.arange(
        start=1,
        end=batch_size * num_of_page_per_token + 1,
        dtype=torch.int32,
        device=block_tables_prefill.device,
    ).reshape(batch_size, num_of_page_per_token)
    grid = (batch_size, num_of_page_per_token)
    # NOTE(review): the row stride handed to the kernel is shape[1], not
    # block_tables_prefill.stride(0), so this assumes the block table rows are
    # densely packed. The kernel reuses the same value to number mock pages.
    _trtllm_prefill_attn_kvfp8_dequant[grid](
        kv_cache,
        block_tables_prefill,
        num_of_page_per_token,
        mock_kv_cache,
        k_scale,
        v_scale,
        k_cache_stride,
        kv_cache_stride,
    )
    return mock_kv_cache, mock_block_table

174

175
176
177
178
class BatchDCPPrefillWrapper:
    """Prefill wrapper that combines a paged "context" pass with a ragged
    "new tokens" pass and merges the two results.

    The context pass attends (non-causally) over the paged KV cache using the
    queries all-gathered across the DCP group; its partial outputs are then
    combined across the group. The new-tokens pass attends causally over the
    ``key``/``value`` tensors passed to :meth:`run`.
    """

    def __init__(
        self,
        workspace_buffer: torch.Tensor | None = None,
        dcp_a2a: bool = False,
    ):
        """Create the two underlying FlashInfer wrappers.

        Args:
            workspace_buffer: Workspace tensor shared by both wrappers.
            dcp_a2a: If True, combine the context results with
                ``dcp_a2a_lse_reduce``; otherwise with ``cp_lse_ag_out_rs``.
        """
        if dcp_a2a:
            self._dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
        else:
            self._dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
        self._context = BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer, get_kv_cache_layout()
        )
        self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(
            workspace_buffer, get_kv_cache_layout()
        )

    def plan(
        self,
        qo_indptr_cpu: torch.Tensor,
        paged_kv_indptr_cpu: torch.Tensor,
        paged_kv_indices: torch.Tensor,
        paged_kv_last_page_len_cpu: torch.Tensor,
        page_size: int,
        num_qo_heads: int,
        dcp_world_size: int,
        num_kv_heads: int,
        head_dim: int,
        sm_scale: float,
        window_left: int,
        logits_soft_cap: float | None,
        q_data_type: torch.dtype,
        kv_cache_dtype: torch.dtype,
        prefill_fixed_split_size: int,
        disable_split_kv: bool,
    ):
        """Plan the prefill operation with given parameters."""
        # The context pass sees num_qo_heads * dcp_world_size query heads
        # because run() all-gathers the queries along dim=1 before calling it.
        self._context.plan(
            qo_indptr=qo_indptr_cpu,
            paged_kv_indptr=paged_kv_indptr_cpu,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len_cpu,
            num_qo_heads=num_qo_heads * dcp_world_size,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim,
            page_size=page_size,
            causal=False,  # This is context run
            sm_scale=sm_scale,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
            kv_data_type=kv_cache_dtype,
            fixed_split_size=prefill_fixed_split_size,
            disable_split_kv=disable_split_kv,
        )
        # The new-tokens pass uses the query indptr for both Q and KV.
        self._new_tokens.plan(
            qo_indptr=qo_indptr_cpu,
            kv_indptr=qo_indptr_cpu,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim,
            head_dim_vo=head_dim,
            causal=True,  # This is newtokens run
            sm_scale=sm_scale,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
        )

    def run(
        self,
        layer: torch.nn.Module,
        prefill_query: torch.Tensor,
        kv_cache_permute: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
    ):
        """Run both passes and merge their results into ``out``.

        Args:
            layer: Module providing ``_k_scale_float`` and ``_v_scale_float``.
            prefill_query: Query tensor of the prefill tokens.
            kv_cache_permute: Paged KV cache passed to the context wrapper.
            key: Keys of the new tokens.
            value: Values of the new tokens.
            out: Output tensor written by ``merge_attn_states``.

        Returns:
            ``out``.
        """
        prefill_query_across_dcp = get_dcp_group().all_gather(
            prefill_query.contiguous(), dim=1
        )
        output_context_tmp, lse_context_tmp = self._context.run(
            prefill_query_across_dcp,
            kv_cache_permute,
            k_scale=layer._k_scale_float,
            v_scale=layer._v_scale_float,
            return_lse=True,
        )
        output_context, lse_context = self._dcp_combine(
            output_context_tmp,
            lse_context_tmp,
            get_dcp_group(),
            return_lse=True,
        )
        # Swap the first two LSE dimensions before merging.
        lse_context = lse_context.transpose(0, 1).contiguous()

        output_query, lse_query = self._new_tokens.run(
            prefill_query,
            key,
            value,
            return_lse=True,
        )
        lse_query = lse_query.transpose(0, 1).contiguous()

        merge_attn_states(
            out,
            output_context,
            lse_context,
            output_query,
            lse_query,
        )
        return out


289
class FlashInferBackend(AttentionBackend):
    """Attention backend descriptor for FlashInfer.

    Exposes the static capabilities of the backend (supported dtypes, block
    sizes, head sizes, KV-cache shape/layout) together with the implementation
    and metadata-builder classes.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
        "fp8",
        "fp8_e4m3",
        "fp8_e5m2",
    ]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the kernel block (page) sizes this backend accepts."""
        # Note: Not sure for all platforms, but on Blackwell,
        # only support a page size of 16, 32, 64.
        return [16, 32, 64]

    @staticmethod
    def get_name() -> str:
        """Return the backend name."""
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> type["FlashInferImpl"]:
        """Return the attention implementation class."""
        return FlashInferImpl

    @staticmethod
    def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        The shape is ``(num_blocks, 2, block_size, num_kv_heads, head_size)``;
        ``cache_dtype_str`` does not affect it.
        """
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the permutation from the logical shape to the memory layout.

        Args:
            include_num_layers_dimension: If True, return a 6-element
                permutation that also covers a leading ``num_layers``
                dimension; otherwise a 5-element one.

        Raises:
            ValueError: If the configured cache layout is neither "NHD" nor
                "HND".
        """
        # `stride_order` indicates the permutation that gets us from
        # `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            if include_num_layers_dimension:
                # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
                return (1, 0, 2, 3, 4, 5)
            return (0, 1, 2, 3, 4)
        if cache_layout == "HND":
            if include_num_layers_dimension:
                # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
                return (1, 2, 4, 0, 3, 5)
            return (0, 1, 3, 2, 4)
        raise ValueError(f"Unknown cache layout format {cache_layout}.")

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Map an fp8 KV-cache dtype string to the matching torch dtype.

        Raises:
            ValueError: If ``kv_cache_dtype`` is not "fp8", "fp8_e4m3" or
                "fp8_e5m2".
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported attention head sizes."""
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        return [64, 128, 256]

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True for compute capabilities from 7.5 to 12.1 inclusive."""
        return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability(
            12, 1
        )

    @classmethod
    def supports_sink(cls) -> bool:
        """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
        from vllm.utils.flashinfer import (
            force_use_trtllm_attention,
            supports_trtllm_attention,
        )

        # Respect explicit disable flag (e.g.,
        # --attention-config.use_trtllm_attention=0)
        if force_use_trtllm_attention() is False:
            return False

        # Check if TRTLLM is supported on this platform
        return supports_trtllm_attention()

    @classmethod
    def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
        """Return "HND" on devices with major capability 10, else None."""
        capability = current_platform.get_device_capability()
        if capability is not None and capability.major == 10:
            return "HND"
        return None

    forward_includes_kv_cache_update: bool = False

394
395

@dataclass
class FIPrefill:
    """Metadata for the native FlashInfer prefill pathway (non-TRTLLM)."""

    # Prefill wrapper carried with this metadata: either the plain paged
    # wrapper or the DCP wrapper defined in this module.
    wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper
400
401


402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
@dataclass
class FIDecode:
    """Decode-side metadata for the non-TRTLLM (native FlashInfer) pathway."""

    # Decode wrapper carried with this metadata.
    wrapper: BatchDecodeWithPagedKVCacheWrapper


@dataclass
class TRTLLMPrefill:
    """Metadata for the TRTLLM prefill pathway."""

    block_tables: torch.Tensor
    """
    The slice of the block table tensor corresponding *only* to prefill requests.
    Shape: [num_prefills, max_num_blocks_per_seq]
    """

    seq_lens: torch.Tensor
    """
    The slice of the sequence lengths tensor corresponding *only* to prefill requests.
    Shape: [num_prefills]
    """

    cum_seq_lens_q: torch.Tensor
    cum_seq_lens_kv: torch.Tensor

    max_q_len: int
    """
    The maximum query length *among prefill requests*.
    """

    max_seq_len: int
    """The maximum sequence length for KV Cache."""


@dataclass
class TRTLLMDecode:
    """Metadata for the TRTLLM decode pathway."""

    block_tables: torch.Tensor
    """
    The slice of the block table tensor corresponding *only* to decode requests.
    Shape: [num_decodes, max_num_blocks_per_seq]
    """

    seq_lens: torch.Tensor
    """
    The slice of the sequence lengths tensor corresponding *only* to decode requests.
    Shape: [num_decodes]
    """

    max_seq_len: int
    """The maximum sequence length for KV Cache."""


@dataclass
class FlashInferMetadata:
    """Per-batch attention metadata produced by the FlashInfer metadata builder."""

    num_actual_tokens: int
    """Total number of tokens in the batch (excluding padding)."""

    slot_mapping: torch.Tensor
    """Tensor for writing K/V to the cache. Shape: [num_actual_tokens]"""

    q_data_type: torch.dtype

    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    prefill: FIPrefill | TRTLLMPrefill | None
    """
    Holds the metadata for the prefill portion of the batch.
    Will be `None` if `num_prefill_tokens == 0`.
    """

    decode: FIDecode | TRTLLMDecode | None
    """
    Holds the metadata for the decode portion of the batch.
    Will be `None` if `num_decode_tokens == 0`.
    """

    # --- Special Case: Cascade Attention ---

    use_cascade: bool
    """
    If True, the entire batch is a cascade attention call, and the
    `prefill` and `decode` fields will both be None.
    """

    cascade_wrapper: MultiLevelCascadeAttentionWrapper | None
493

494

495
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
496
    reorder_batch_threshold: int = 1
497

498
499
500
501
502
503
504
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up configuration, dispatch decisions and persistent buffers.

        Args:
            kv_cache_spec: KV-cache spec providing block size, number of KV
                heads, head size and dtype.
            layer_names: Names of the attention layers served by this builder.
            vllm_config: Global vLLM configuration.
            device: Device on which the GPU-side buffers are allocated.

        Raises:
            NotImplementedError: If the layers use attention sinks but TRTLLM
                attention cannot be used.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.cache_config = vllm_config.cache_config
        self.model_config = vllm_config.model_config
        self.attention_config = vllm_config.attention_config
        # Wrappers and the workspace buffer are created lazily by the
        # corresponding _get_* helpers.
        self._workspace_buffer = None
        self._prefill_wrapper: (
            BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
        ) = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)

        # Batch-invariant mode uses fixed split sizes and disables split-KV.
        if vllm_is_batch_invariant():
            self.decode_fixed_split_size = 2048
            self.prefill_fixed_split_size = 4096
            self.disable_split_kv = True
        else:
            self.decode_fixed_split_size = -1
            self.prefill_fixed_split_size = -1
            self.disable_split_kv = False

        self.compilation_config = vllm_config.compilation_config
        max_num_pages_per_req = cdiv(
            self.model_config.max_model_len, self.kv_cache_spec.block_size
        )
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        max_num_pages = max_num_reqs * max_num_pages_per_req
        speculative_config = vllm_config.speculative_config
        num_spec_tokens = (
            speculative_config.num_speculative_tokens
            if speculative_config is not None
            else 0
        )
        self.enable_cuda_graph = (
            self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
        )
        if self.enable_cuda_graph:
            # For full cudagraph capture, one `decode_wrapper` for each batch
            # size is needed for FlashInfer.
            self._decode_wrappers_cudagraph: dict[
                int, BatchDecodeWithPagedKVCacheWrapper
            ] = {}
            self._decode_cudagraph_max_bs = (1 + num_spec_tokens) * max_num_reqs
            if self.compilation_config.max_cudagraph_capture_size is not None:
                self._decode_cudagraph_max_bs = min(
                    self._decode_cudagraph_max_bs,
                    self.compilation_config.max_cudagraph_capture_size,
                )
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
            self.dcp_kv_cache_interleave_size = (
                vllm_config.parallel_config.dcp_kv_cache_interleave_size
            )
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
            self.dcp_kv_cache_interleave_size = 1
        self.use_dcp = self.dcp_world_size > 1
        self.dcp_a2a = (
            self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a"
        )

        self.num_qo_heads = self.model_config.get_num_attention_heads(
            self.vllm_config.parallel_config
        )

        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
        self.head_dim = self.kv_cache_spec.head_size
        self.page_size = self.kv_cache_spec.block_size

        self.cache_dtype = self.cache_config.cache_dtype
        if self.cache_dtype.startswith("fp8"):
            self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.cache_dtype
            )
        else:
            assert self.kv_cache_spec.dtype == self.model_config.dtype
            self.kv_cache_dtype = self.kv_cache_spec.dtype

        # Use model dtype as q dtype when TRTLLM attn is not supported, or
        # --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
        # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
        # if TRTLLM attention kernel is not used when building attn metadata
        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
        if (
            can_use_trtllm
            and not vllm_config.attention_config.disable_flashinfer_q_quantization
        ):
            self.q_data_type = self.kv_cache_dtype
        else:
            self.q_data_type = self.model_config.dtype

        # Prefer TRTLLM attention for decoding in all cases.
        # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
        self.use_trtllm_decode_attention = can_use_trtllm
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)

        self._cascade_wrapper = None  # Wrapper for cascade attention

        # Global hyperparameters shared by all attention layers
        # TODO: discard this for trtllm-gen backend
        self.global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)
        )
        self.sm_scale = self.global_hyperparameters.sm_scale
        self.window_left = self.global_hyperparameters.window_left
        self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
        self.has_sinks = self.global_hyperparameters.has_sinks
        if self.has_sinks and not can_use_trtllm:
            raise NotImplementedError(
                "FlashInfer backend currently does not support attention "
                "sinks, please use trtllm on blackwell or flash attention on "
                "earlier GPUs."
            )
        # Preparing persistent buffers
        # Since we do not have explicit synchronization in ModelRunnerV2, we do not pin
        # reused CPU buffers to avoid a race condition between step N async copies to
        # GPU and step N+1 buffer updates.
        self.pin_memory = (
            not envs.VLLM_USE_V2_MODEL_RUNNER and is_pin_memory_available()
        )
        self.paged_kv_indptr = self._make_buffer(max_num_reqs + 1)
        self.paged_kv_indptr_cpu_buffer = torch.zeros_like(
            self.paged_kv_indptr.cpu, pin_memory=self.pin_memory
        )  # Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode
        self.paged_kv_indices = self._make_buffer(max_num_pages)
        self.paged_kv_last_page_len = self._make_buffer(max_num_reqs)
632

633
634
635
636
637
638
639
640
641
642
643
644
    def _make_buffer(
        self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32
    ) -> CpuGpuBuffer:
        """Create a ``CpuGpuBuffer`` of the given size on this builder's device.

        The buffer is created with numpy access enabled and follows the
        builder's ``pin_memory`` setting; ``dtype`` defaults to ``torch.int32``.
        """
        buffer_options = {
            "dtype": dtype,
            "device": self.device,
            "pin_memory": self.pin_memory,
            "with_numpy": True,
        }
        return CpuGpuBuffer(*size, **buffer_options)

    @override  # type: ignore[misc]
    @classmethod
    def get_cudagraph_support(
        cls: type["FlashInferMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        """Get the cudagraph support level for FlashInfer attention.

        This depends on whether we can use TRTLLM attention for decodes, since we can
        only do UNIFORM_SINGLE_TOKEN_DECODE if it is unavailable.
        To check this, we must call can_use_trtllm_attention with the number of KV
        heads from the kv_cache_spec. We check all available KV cache specs and
        only return UNIFORM_BATCH if all of them support TRTLLM attention.
        """
        # For UniformTypeKVCacheSpecs, check all contained specs
        kv_specs = (
            kv_cache_spec.kv_cache_specs.values()
            if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs)
            else [kv_cache_spec]
        )
        num_qo_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )
        # An empty collection of specs counts as "no TRTLLM support".
        has_trtllm_support: bool = len(kv_specs) > 0
        for spec in kv_specs:
            if not isinstance(spec, AttentionSpec):
                # FlashInfer only applies to attention, so we don't consider other types
                # of KV spec (e.g. Mamba) here. This is mostly for type checking.
                continue
            if not can_use_trtllm_attention(
                num_qo_heads=num_qo_heads,
                num_kv_heads=spec.num_kv_heads,
            ):
                has_trtllm_support = False
                break

        if has_trtllm_support:
            return AttentionCGSupport.UNIFORM_BATCH
        else:
            return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

686
687
    def _get_workspace_buffer(self):
        """Return the FlashInfer workspace buffer, allocating it on first use.

        The buffer is a zero-filled ``torch.uint8`` tensor on ``self.device``.
        Its size is ``envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE``, or
        ``FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT`` when batch-invariant
        mode is enabled. A buffer previously installed through
        ``set_workspace_buffer`` is returned as-is.
        """
        if self._workspace_buffer is None:
            buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
            if vllm_is_batch_invariant():
                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
            self._workspace_buffer = torch.zeros(
                buffer_size, dtype=torch.uint8, device=self.device
            )
        return self._workspace_buffer

Woosuk Kwon's avatar
Woosuk Kwon committed
696
697
698
    def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
        """Install ``workspace_buffer`` as the buffer that
        ``_get_workspace_buffer`` will return from now on."""
        self._workspace_buffer = workspace_buffer

699
700
701
    def _get_prefill_wrapper(
        self,
    ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
        """Return the prefill wrapper, creating and caching it on first use.

        A ``BatchDCPPrefillWrapper`` is created when ``self.use_dcp`` is set;
        otherwise a plain ``BatchPrefillWithPagedKVCacheWrapper`` is used.
        """
        if self._prefill_wrapper is None:
            if self.use_dcp:
                self._prefill_wrapper = BatchDCPPrefillWrapper(
                    workspace_buffer=self._get_workspace_buffer(),
                    dcp_a2a=self.dcp_a2a,
                )
            else:
                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                    self._get_workspace_buffer(), get_kv_cache_layout()
                )
        assert self._prefill_wrapper is not None
        return self._prefill_wrapper

715
    def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
        """Return the decode wrapper, creating and caching it on first use.

        Args:
            batch_size: Batch size the wrapper is used for. It only selects
                the wrapper (and sizes its buffers) when ``use_cudagraph`` is
                True.
            use_cudagraph: If True, use a per-batch-size wrapper bound to
                slices of the persistent GPU buffers; otherwise use the single
                general-shape wrapper.
        """
        if use_cudagraph:
            decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None)
        else:
            decode_wrapper = self._decode_wrapper

        if decode_wrapper is None:
            if use_cudagraph:
                # Bind the wrapper to the persistent GPU buffers.
                paged_kv_indptr = self.paged_kv_indptr.gpu[: batch_size + 1]
                paged_kv_indices = self.paged_kv_indices.gpu
                paged_kv_last_page_len = self.paged_kv_last_page_len.gpu[:batch_size]
            else:
                paged_kv_indptr = None
                paged_kv_indices = None
                paged_kv_last_page_len = None
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                get_kv_cache_layout(),
                use_cuda_graph=use_cudagraph,
                paged_kv_indptr_buffer=paged_kv_indptr,
                paged_kv_indices_buffer=paged_kv_indices,
                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                # Tensor cores are enabled by default because the perf would be
                # at least as good as cuda cores for all attention ops in latest
                # gpus.
                use_tensor_cores=True,
            )

            # save the decode wrapper
            if use_cudagraph:
                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
            else:
                self._decode_wrapper = decode_wrapper

        return decode_wrapper
750
751
752
753

    def _get_cascade_wrapper(self):
        """Return the two-level cascade attention wrapper, creating it on first use."""
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
                2, self._get_workspace_buffer(), get_kv_cache_layout()
            )
        return self._cascade_wrapper

758
    def _compute_flashinfer_kv_metadata(
        self,
        num_blocks_np: np.ndarray,
        seq_lens_np: np.ndarray,
        block_table_tensor: torch.Tensor,
        num_reqs: int,
        page_size: int,
    ) -> torch.Tensor:
        """
        Compute paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len for FlashInfer
        attention.

        Results are stored in self.paged_kv_indptr,
        self.paged_kv_indices, self.paged_kv_last_page_len buffers.

        Args:
            num_blocks_np: Number of KV-cache pages used by each request.
            seq_lens_np: Sequence length of each request.
            block_table_tensor: Block table from which page indices are copied.
            num_reqs: Number of requests to process.
            page_size: Number of tokens per page.

        Returns paged_kv_indices, a GPU tensor with shape [num_actual_pages].
        """
        # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
        np.cumsum(
            num_blocks_np,
            dtype=np.int32,
            out=self.paged_kv_indptr.np[1 : num_reqs + 1],
        )
        # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
        # after this line (e.g., for cuda graphs), we need to copy the data to
        # self.paged_kv_indptr_cpu_buffer to avoid race condition.
        self.paged_kv_indptr_cpu_buffer[: num_reqs + 1] = self.paged_kv_indptr.cpu[
            : num_reqs + 1
        ]
        paged_kv_indptr = self.paged_kv_indptr.gpu[: num_reqs + 1]
        paged_kv_indptr.copy_(
            self.paged_kv_indptr_cpu_buffer[: num_reqs + 1], non_blocking=True
        )

        # write self.paged_kv_indices inplace
        # The last cumulative-sum entry is the total number of pages in use.
        num_actual_pages = self.paged_kv_indptr.np[num_reqs]
        paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages]
        _copy_page_indices_kernel[(num_reqs,)](
            paged_kv_indices,
            block_table_tensor,
            block_table_tensor.stride(0),
            paged_kv_indptr,
            BLOCK_SIZE=1024,
        )

        # write self.paged_kv_last_page_len_cpu inplace
        # A non-empty sequence whose length is a multiple of page_size has a
        # full last page (page_size tokens), not an empty one.
        paged_kv_last_page_len_np = seq_lens_np % page_size
        self.paged_kv_last_page_len.np[:num_reqs] = np.where(
            (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
            page_size,
            paged_kv_last_page_len_np,
        )
        self.paged_kv_last_page_len.gpu[:num_reqs].copy_(
            self.paged_kv_last_page_len.cpu[:num_reqs], non_blocking=True
        )
        return paged_kv_indices
814

815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashInferMetadata:
        """Build the FlashInfer attention metadata for one batch.

        The batch is split into decode requests (front) and prefill requests
        (back). Depending on the batch contents and the configuration, each
        part is dispatched to the FlashInfer native wrappers or to the TRTLLM
        kernels, or the whole batch is handled by the cascade wrapper.

        Args:
            common_prefix_len: Length in tokens of the prefix shared by all
                requests. A value > 0 selects cascade attention and must be
                a multiple of the page size.
            common_attn_metadata: Backend-agnostic per-batch metadata
                (sequence lengths, block table, query start locations, ...).
            fast_build: Not used by this implementation.

        Returns:
            A FlashInferMetadata whose prefill / decode / cascade_wrapper
            fields are populated according to the selected pathways.

        Raises:
            NotImplementedError: If a non-TRTLLM pathway is required while
                attention sinks are present.
            ValueError: If a non-TRTLLM pathway is required and the layers do
                not share the same window_left.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True,
            )
        )

        page_size = self.page_size
        max_seq_len = common_attn_metadata.max_seq_len
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        qo_indptr = common_attn_metadata.query_start_loc
        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu

        # Step 1: Decide which dispatch modes to use:
        # - Cascade attention (distinct mode)
        # - Prefill (FI native or TRTLLM)
        # - Decode (FI native or TRTLLM)
        use_cascade = common_prefix_len > 0
        uses_spec_reorder = self.reorder_batch_threshold > 1
        prefill_use_trtllm = use_trtllm_attention(
            self.num_qo_heads,
            self.num_kv_heads,
            num_prefill_tokens,
            max_seq_len,
            self.dcp_world_size,
            self.cache_dtype,
            self.q_data_type,
            is_prefill=True,
            force_use_trtllm=self.attention_config.use_trtllm_attention,
            has_sinks=self.has_sinks,
            has_spec=uses_spec_reorder,
        )
        decode_use_trtllm = (
            self.use_trtllm_decode_attention and self.dcp_world_size <= 1
        )

        all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and (
            num_decodes == 0 or decode_use_trtllm
        )
        is_only_trtllm_decode = num_prefills == 0 and (
            num_decodes > 0 and decode_use_trtllm
        )

        if not all_uses_trtllm:
            if self.has_sinks:
                raise NotImplementedError(
                    "FlashInfer backend currently does not support attention "
                    "sinks, please use trtllm on blackwell or flash attention "
                    "on earlier GPUs."
                )

            if not self.global_hyperparameters.has_same_window_lefts:
                raise ValueError(
                    "Window left is not the same for all layers. "
                    "One potential fix is to set disable_sliding_window=True"
                )

            assert self.global_hyperparameters.has_same_all_params, (
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values for the following "
                "hyperparameters: `window_left`, `logits_soft_cap`, "
                "`sm_scale`."
            )

            # The q quantization is not supported for non-trtllm attention,
            # fall back to model dtype.
            self.q_data_type = self.model_config.dtype

        # Step 2: Initialize the output metadata
        # Leave prefill/decode/cascade_wrapper empty, to be populated
        # case by case depending on the batch contents and backend selection.
        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            slot_mapping=common_attn_metadata.slot_mapping,
            q_data_type=self.q_data_type,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            use_cascade=use_cascade,
            prefill=None,
            decode=None,
            cascade_wrapper=None,
        )

        # Guard access to seq_lens_cpu, which may not always be needed
        # and can be expensive to retrieve in async mode.
        needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
        seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
        # Number of pages per request (ceil division by the page size).
        num_blocks_np = (
            (seq_lens_np + (page_size - 1)) // page_size
            if seq_lens_np is not None
            else None
        )

        # Adjust seq_lens_cpu for DCP
        if self.use_dcp:
            assert seq_lens_cpu is not None
            if num_prefills > 0:
                qo_indptr_prefill_cpu = (
                    qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
                )
                query_lens_prefill_cpu = (
                    qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
                )
                seq_lens_cpu[num_decodes:] = (
                    seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
                )

            seq_lens_cpu = get_dcp_local_seq_lens(
                seq_lens_cpu,
                self.dcp_world_size,
                self.dcp_rank,
                self.dcp_kv_cache_interleave_size,
            )

        # Adjust num_blocks_np and the block table for cascade attention.
        # This must happen BEFORE paged_kv_indices is computed so that the
        # per-request page indices cover only the pages after the shared
        # prefix (the second cascade level).
        num_common_kv_blocks = 0
        shared_kv_page_indices_cpu: torch.Tensor | None = None
        if use_cascade:
            assert num_blocks_np is not None
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size
            # Grab the blocks of the shared prefix from the first request.
            # NOTE(review): despite the `_cpu` suffix this is a slice of
            # block_table_tensor and lives on that tensor's device.
            shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
            # Remove the blocks of the shared prefix from all requests.
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
            num_blocks_np -= num_common_kv_blocks

        # Compute paged_kv_indices if necessary
        needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode
        if needs_paged_kv_indices:
            assert num_blocks_np is not None
            assert seq_lens_np is not None
            paged_kv_indices = self._compute_flashinfer_kv_metadata(
                num_blocks_np,
                seq_lens_np,
                block_table_tensor,
                num_reqs,
                page_size,
            )
        else:
            paged_kv_indices = None

        # Early-out for cascade attention
        if use_cascade:
            assert shared_kv_page_indices_cpu is not None
            assert paged_kv_indices is not None

            # Create CPU versions directly for cascade (no GPU versions needed)
            # Level 0: a single "request" spanning every query token that
            # attends to the shared prefix pages.
            shared_qo_indptr_cpu = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device="cpu"
            )
            shared_kv_page_indptr_cpu = torch.tensor(
                [0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
            )
            # The shared prefix is page aligned, so its last page is full.
            shared_kv_last_page_len_cpu = torch.tensor(
                [page_size], dtype=torch.int32, device="cpu"
            )

            # Level 1: per-request metadata filled in by
            # _compute_flashinfer_kv_metadata above.
            paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[: 1 + num_reqs]
            paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]

            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                qo_indptr_arr=[shared_qo_indptr_cpu, qo_indptr_cpu],
                paged_kv_indptr_arr=[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                paged_kv_indices_arr=[shared_kv_page_indices_cpu, paged_kv_indices],
                paged_kv_last_page_len=[
                    shared_kv_last_page_len_cpu,
                    paged_kv_last_page_len_cpu,
                ],
                num_qo_heads=self.num_qo_heads,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                page_size=self.page_size,
                causal=True,
                sm_scale=self.sm_scale,
                window_left=self.window_left,
                logits_soft_cap=self.logits_soft_cap,
                q_data_type=self.q_data_type,
                kv_data_type=self.kv_cache_dtype,
            )
            return attn_metadata

        # Step 3: Handle prefill and decode pathways case by case
        ## PREFILL PATHWAY
        if num_prefills > 0:
            # Slices for shared prefill metadata
            # Prefill requests come after the decode requests; rebase the
            # query indptr so that it starts at 0 for the first prefill.
            prefill_start = num_decodes
            qo_indptr_prefill_cpu = (
                qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
            )
            assert qo_indptr_prefill_cpu.shape[0] == num_prefills + 1

            if prefill_use_trtllm:
                # Create GPU versions
                qo_indptr_prefill_gpu = (
                    qo_indptr[prefill_start:] - qo_indptr[prefill_start]
                )
                paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
                    prefill_start : num_reqs + 1
                ]
                # Compute max_q_len for prefill requests
                query_lens_prefill_cpu = (
                    qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
                )
                max_q_len_prefill = int(query_lens_prefill_cpu.max().item())
                attn_metadata.prefill = TRTLLMPrefill(
                    block_tables=block_table_tensor[prefill_start:],
                    seq_lens=seq_lens[prefill_start:],
                    cum_seq_lens_q=qo_indptr_prefill_gpu,
                    cum_seq_lens_kv=paged_kv_indptr_prefill_gpu,
                    max_q_len=max_q_len_prefill,
                    max_seq_len=max_seq_len,
                )
            else:
                prefill_wrapper = self._get_prefill_wrapper()
                # Slicing CPU buffers that are only needed for FI native prefills
                paged_kv_last_page_len_prefill_cpu = self.paged_kv_last_page_len.cpu[
                    prefill_start:num_reqs
                ]
                assert paged_kv_last_page_len_prefill_cpu.shape[0] == num_prefills
                paged_kv_indptr_prefill_cpu = self.paged_kv_indptr.cpu[
                    prefill_start : num_reqs + 1
                ]
                assert paged_kv_indptr_prefill_cpu.shape[0] == num_prefills + 1
                if self.use_dcp:
                    assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
                    prefill_wrapper.plan(
                        qo_indptr_cpu=qo_indptr_prefill_cpu,
                        paged_kv_indptr_cpu=paged_kv_indptr_prefill_cpu,
                        paged_kv_indices=paged_kv_indices,
                        paged_kv_last_page_len_cpu=paged_kv_last_page_len_prefill_cpu,
                        page_size=self.page_size,
                        num_qo_heads=self.num_qo_heads,
                        dcp_world_size=self.dcp_world_size,
                        num_kv_heads=self.num_kv_heads,
                        head_dim=self.head_dim,
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_cache_dtype=self.kv_cache_dtype,
                        prefill_fixed_split_size=self.prefill_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
                    )
                else:
                    assert isinstance(
                        prefill_wrapper,
                        BatchPrefillWithPagedKVCacheWrapper,
                    )
                    prefill_wrapper.plan(
                        qo_indptr=qo_indptr_prefill_cpu,
                        paged_kv_indptr=paged_kv_indptr_prefill_cpu,
                        paged_kv_indices=paged_kv_indices,
                        paged_kv_last_page_len=paged_kv_last_page_len_prefill_cpu,
                        num_qo_heads=self.num_qo_heads,
                        num_kv_heads=self.num_kv_heads,
                        head_dim_qk=self.head_dim,
                        page_size=self.page_size,
                        causal=True,
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                        o_data_type=self.model_config.dtype,
                        fixed_split_size=self.prefill_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
                    )
                attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper)

        ## DECODE PATHWAY
        if num_decodes > 0:
            if decode_use_trtllm:
                assert num_decode_tokens % num_decodes == 0, (
                    "TRTLLM decode requires uniform query lengths per request. "
                    f"Got {num_decode_tokens=} and {num_decodes=}."
                )
                attn_metadata.decode = TRTLLMDecode(
                    block_tables=block_table_tensor[:num_decodes],
                    seq_lens=seq_lens[:num_decodes],
                    max_seq_len=max_seq_len,
                )
            else:
                assert seq_lens_cpu is not None
                pure_decode = num_prefills == 0
                use_cudagraph = (
                    self.enable_cuda_graph
                    and pure_decode
                    and num_decode_tokens <= self._decode_cudagraph_max_bs
                )
                num_input_tokens = num_decode_tokens

                decode_wrapper = self._get_decode_wrapper(
                    num_input_tokens, use_cudagraph
                )
                # Use the persistent buffer with padding length,
                # instead of the same address but chunked version
                # in attn_metadata when using cudagraph.
                fast_plan_decode(
                    decode_wrapper,
                    indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
                    indices=paged_kv_indices,
                    last_page_len_cpu=self.paged_kv_last_page_len.cpu[
                        :num_input_tokens
                    ],
                    num_qo_heads=self.num_qo_heads * self.dcp_world_size,
                    num_kv_heads=self.num_kv_heads,
                    head_dim=self.head_dim,
                    page_size=self.page_size,
                    # Disable flashinfer's pos encoding and use vllm's rope.
                    pos_encoding_mode="NONE",
                    sm_scale=self.sm_scale,
                    window_left=self.window_left,
                    logits_soft_cap=self.logits_soft_cap,
                    q_data_type=self.q_data_type,
                    kv_data_type=self.kv_cache_dtype,
                    o_data_type=self.model_config.dtype,
                    fixed_split_size=self.decode_fixed_split_size,
                    disable_split_kv=self.disable_split_kv,
                )
                attn_metadata.decode = FIDecode(wrapper=decode_wrapper)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Report whether cascade attention should be used for this batch.

        Always returns False at the moment: cascade attention is disabled
        regardless of the arguments, which are accepted but ignored.
        """
        kv_dtype = self.kv_cache_spec.dtype
        model_dtype = self.vllm_config.model_config.dtype
        dtype_mismatch = kv_dtype != model_dtype
        if dtype_mismatch:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        # TODO: Cascade attention doesn't work, disable it for now
        # return use_cascade_attention(*args, **kwargs)
        return False
1161
1162
1163


class FlashInferImpl(AttentionImpl):
1164
1165
    can_return_lse_for_decode: bool = True

1166
1167
1168
1169
1170
1171
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Set up a FlashInfer attention layer implementation.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale; stored as a Python float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes; stored as a float32 tensor
                when given.
            sliding_window: Optional sliding-window size. ``None`` disables
                the window (window_left == -1); otherwise window_left is
                ``sliding_window - 1``.
            kv_cache_dtype: KV cache dtype string (e.g. starting with "fp8").
            logits_soft_cap: Optional logits soft cap.
            attn_type: Attention type; only AttentionType.DECODER is
                supported.
            kv_sharing_target_layer_name: Optional name of the layer whose
                KV cache this layer shares.
            sinks: Optional attention sinks; the first dimension must equal
                ``num_heads``.

        Raises:
            NotImplementedError: If ``attn_type`` is not
                AttentionType.DECODER.
            ValueError: If ``sinks`` does not have ``num_heads`` entries in
                its first dimension.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        # self.sliding_window is always a tuple at this point, so its first
        # element is the left window size (-1 when there is no window).
        self.window_left = self.sliding_window[0]
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferImpl"
            )

        self.sinks: torch.Tensor | None = None
        if sinks is not None:
            if sinks.shape[0] != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}."
                )
            self.sinks = sinks

        self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
        vllm_config = get_current_vllm_config_or_none()
        # Quantized query input needs TRTLLM attention support, an active
        # vLLM config, and must not be disabled in the attention config.
        self.supports_quant_query_input = (
            self.support_trtllm_attn
            and vllm_config is not None
            and not vllm_config.attention_config.disable_flashinfer_q_quantization
        )
        # Scales start unset here (None).
        self.bmm1_scale: float | None = None
        self.bmm2_scale: float | None = None
        self.o_sf_scale: float | None = None

        # Pick the LSE-combine routine used for decode context parallelism:
        # the all-to-all variant when DCP is enabled with the "a2a" comm
        # backend, otherwise cp_lse_ag_out_rs.
        dcp_a2a = (
            vllm_config is not None
            and vllm_config.parallel_config.decode_context_parallel_size > 1
            and vllm_config.parallel_config.dcp_comm_backend == "a2a"
        )
        if dcp_a2a:
            self.dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
        else:
            self.dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)

1239
    def fused_output_quant_supported(self, quant_key: QuantKey):
        """Return whether attention + output-quant fusion supports quant_key.

        Fusion requires TRTLLM attention support, an fp8 KV cache, and a
        quantization key that is either static per-tensor symmetric FP8 or
        dynamic NVFP4.
        """
        # Cheap attribute checks first; the key is only inspected when the
        # layer can use the fused kernel at all.
        if not (self.support_trtllm_attn and self.kv_cache_dtype.startswith("fp8")):
            return False
        return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
1245

1246
1247
1248
1249
1250
    # FlashInfer requires attention sinks to be float32
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Convert the attention sinks, if present, to float32.

        ``act_dtype`` is accepted but not used.
        """
        sinks = self.sinks
        if sinks is None or sinks.dtype == torch.float32:
            # Nothing to convert.
            return
        self.sinks = sinks.to(torch.float32)

1251
1252
1253
1254
1255
1256
1257
1258
    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
1259
1260
1261
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
1262
1263
1264
1265
1266
1267
1268
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
1269
1270
1271
            kv_cache: KV cache tensor with different possible shapes:
                - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
                - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
1272
1273
1274
1275
1276
1277
1278
1279
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
1280
            return output.fill_(0)
1281

1282
1283
1284
1285
1286
1287
        # Ensure query dtype matches the expected dtype from attention metadata
        assert attn_metadata.q_data_type == query.dtype, (
            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
            f"got {query.dtype}"
        )

1288
        if self.bmm1_scale is None:
1289
            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
1290
1291
1292
1293

        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

1294
1295
1296
        prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
        decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)

1297
1298
        # The attn+quant fusion happens when output_scale is provided.
        if output_scale is None:
1299
1300
1301
            assert output_block_scale is None, (
                "output_block_scale is not supported when fusion has not happened"
            )
1302
        else:
1303
            assert attn_metadata.q_data_type == FP8_DTYPE, (
1304
                "Query must be FP8 when attn+quant fusion happened."
1305
            )
1306
1307
            assert (attn_metadata.num_prefills == 0 or prefill_use_trtllm) and (
                attn_metadata.num_decodes == 0 or decode_use_trtllm
1308
            ), "Must use TRT-LLM attn"
1309

1310
            if output.dtype == FP8_DTYPE:
1311
                assert output_block_scale is None, (
1312
                    "output_block_scale should not be provided for fp8 output"
1313
                )
1314
            elif output.dtype == FP4_DTYPE:
1315
                assert output_block_scale is not None, (
1316
                    "output_block_scale is required for nvfp4 output"
1317
                )
1318
1319
1320
            else:
                raise ValueError(f"Unsupported output dtype: {output.dtype}")

1321
            # TRTLLM attn kernel requires to scale to pass as a host scalar,
1322
1323
            # store the o scale as a host scalar in warmup run with cuda graph
            # not enabled
1324
1325
            if layer._o_scale_float is None:
                layer._o_scale_float = output_scale.cpu().item()
1326
1327
1328
1329
                if output.dtype == FP8_DTYPE:
                    self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
                elif output.dtype == FP4_DTYPE:
                    self.o_sf_scale = layer._o_scale_float
1330

1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
1341

1342
1343
1344
1345
1346
1347
1348
        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype
1349
            )
1350
            kv_cache = kv_cache.view(torch_dtype)
1351

1352
1353
        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
1354
1355
        key = key[:num_actual_tokens]
        value = value[:num_actual_tokens]
1356
1357
1358
1359
1360
1361
1362
1363
1364
        output_padded = output
        output = output[:num_actual_tokens]

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            return output

1365
1366
        # When using spec decoding, num_decodes can be < num_decode_tokens
        # because some decode requests may have more than one query token.
1367
1368
1369
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

1370
        stride_order = FlashInferBackend.get_kv_cache_stride_order()
1371
        kv_cache_permute = kv_cache.permute(*stride_order)
1372
1373
1374

        use_dcp = self.dcp_world_size > 1

1375
        # Regular attention (common case).
1376
        # Decodes are at the front and prefills are at the back.
1377
        if num_prefill_tokens > 0:
1378
1379
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
1380

1381
1382
1383
1384
1385
            if not prefill_use_trtllm:
                assert isinstance(attn_metadata.prefill, FIPrefill)
                prefill_wrapper = attn_metadata.prefill.wrapper
                assert prefill_wrapper is not None
                if use_dcp:
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
                    assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
                    assert prefill_wrapper._context._window_left == self.window_left
                    assert prefill_wrapper._context._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._context._sm_scale == self.scale
                    assert not prefill_wrapper._context._causal
                    assert prefill_wrapper._new_tokens._window_left == self.window_left
                    assert prefill_wrapper._new_tokens._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._new_tokens._sm_scale == self.scale
                    assert prefill_wrapper._new_tokens._causal

                    prefill_wrapper.run(
                        layer,
                        prefill_query,
                        kv_cache_permute,
                        key[num_decode_tokens:],
                        value[num_decode_tokens:],
                        out=output[num_decode_tokens:],
                    )
                else:
                    assert isinstance(
                        prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper
                    )
                    assert prefill_wrapper._window_left == self.window_left
                    assert prefill_wrapper._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._sm_scale == self.scale
                    assert prefill_wrapper._causal
                    prefill_wrapper.run(
                        prefill_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[num_decode_tokens:],
                    )
1425
            else:
1426
                assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
1427
1428
1429
1430
1431
                # prefill_query may be non-contiguous or have degenerate strides
                # First ensure memory contiguity, then fix degenerate strides
                # with reshape. contiguous() alone doesn't fix degenerate
                # strides when a dimension has size 1.
                prefill_query = prefill_query.contiguous().reshape(prefill_query.shape)
1432
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
1433
1434
                block_tables_prefill = attn_metadata.prefill.block_tables
                seq_lens_prefill = attn_metadata.prefill.seq_lens
1435
1436
1437

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
1438
1439
1440
1441
1442
                assert is_strictly_contiguous(prefill_query)
                assert is_strictly_contiguous(kv_cache_permute)
                assert is_strictly_contiguous(workspace_buffer)
                assert is_strictly_contiguous(block_tables_prefill)
                assert is_strictly_contiguous(seq_lens_prefill)
1443

1444
1445
                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
1446
1447
1448
1449
1450
1451
                    out = FP4Tensor(
                        data=output[num_decode_tokens:],
                        scale=output_block_scale,
                        scale_start_index=num_decode_tokens,
                        original_shape=prefill_query.shape,
                    )
1452
1453
1454
1455
                else:
                    assert self.o_sf_scale is None
                    out = output[num_decode_tokens:]

1456
1457
1458
1459
                if (
                    attn_metadata.q_data_type != FP8_DTYPE
                    and self.kv_cache_dtype.startswith("fp8")
                ):
1460
1461
1462
1463
                    # TRTLLM prefill attention does not support BF16 Q
                    # and fp8 kv cache. So to enable prefill attention
                    # with fp8 kv cache, we can construct a mock block
                    # and mock kv cache with BF16 KV involved in the prefill
1464
1465
1466
1467
1468
1469
1470
                    mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
                        kv_cache_permute,
                        block_tables_prefill,
                        layer._k_scale,
                        layer._v_scale,
                        attn_metadata.q_data_type,
                    )
1471
1472
1473
1474
                else:
                    mock_kv_cache = kv_cache_permute
                    mock_block_table = block_tables_prefill

1475
1476
                trtllm_batch_context_with_kv_cache(
                    query=prefill_query,
1477
                    kv_cache=mock_kv_cache,
1478
                    workspace_buffer=workspace_buffer,
1479
                    block_tables=mock_block_table,
1480
                    seq_lens=seq_lens_prefill,
1481
1482
                    max_q_len=attn_metadata.prefill.max_q_len,
                    max_kv_len=attn_metadata.prefill.max_seq_len,
1483
1484
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
1485
                    batch_size=attn_metadata.num_prefills,
1486
1487
                    cum_seq_lens_q=attn_metadata.prefill.cum_seq_lens_q,
                    cum_seq_lens_kv=attn_metadata.prefill.cum_seq_lens_kv,
1488
                    window_left=self.window_left,
1489
                    sinks=self.sinks,
1490
1491
                    o_sf_scale=self.o_sf_scale,
                    out=out,
1492
1493
1494
                )

        if num_decode_tokens > 0:
1495
1496
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
1497

1498
1499
1500
1501
            if not decode_use_trtllm:
                assert isinstance(attn_metadata.decode, FIDecode)
                decode_wrapper = attn_metadata.decode.wrapper
                assert decode_wrapper is not None
1502
                assert decode_wrapper._window_left == self.window_left
1503
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
1504
                assert decode_wrapper._sm_scale == self.scale
1505

1506
                if use_dcp:
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
                    decode_query = get_dcp_group().all_gather(
                        decode_query.contiguous(), dim=-2
                    )
                    output_tmp = torch.empty_like(decode_query)
                    lse = torch.empty(
                        (decode_query.size(0), decode_query.size(1)),
                        dtype=torch.float32,
                        device=decode_query.device,
                    )
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output_tmp,
                        lse=lse,
                        return_lse=True,
                    )
1525
                    output[:num_decode_tokens] = self.dcp_combine(
1526
1527
1528
                        output_tmp,
                        lse,
                        get_dcp_group(),
1529
1530
1531
1532
1533
1534
1535
1536
1537
                    )
                else:
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[:num_decode_tokens],
                    )
1538
            else:
1539
                # decode_query may be non-contiguous or have degenerate strides
1540
                assert isinstance(attn_metadata.decode, TRTLLMDecode)
1541
1542
1543
1544
                # First ensure memory contiguity, then fix degenerate strides
                # with reshape. contiguous() alone doesn't fix degenerate
                # strides when a dimension has size 1.
                decode_query = decode_query.contiguous().reshape(decode_query.shape)
1545
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
1546
1547
                block_tables_decode = attn_metadata.decode.block_tables
                seq_lens_decode = attn_metadata.decode.seq_lens
1548

1549
                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
1550
                assert get_kv_cache_layout() == "HND"
1551
1552
1553
1554
1555
                assert is_strictly_contiguous(decode_query)
                assert is_strictly_contiguous(kv_cache_permute)
                assert is_strictly_contiguous(workspace_buffer)
                assert is_strictly_contiguous(block_tables_decode)
                assert is_strictly_contiguous(seq_lens_decode)
1556

1557
1558
                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
1559
1560
1561
1562
1563
1564
                    out = FP4Tensor(
                        data=output[:num_decode_tokens],
                        scale=output_block_scale,
                        scale_start_index=0,
                        original_shape=decode_query.shape,
                    )
1565
1566
1567
1568
                else:
                    assert self.o_sf_scale is None
                    out = output[:num_decode_tokens]

1569
1570
1571
1572
1573
                if num_decode_tokens % attn_metadata.num_decodes != 0:
                    # This gets triggered when the dummy_run forces
                    # attention to be initialized with q_len = 0
                    q_len_per_req = 1
                else:
1574
                    q_len_per_req = num_decode_tokens // attn_metadata.num_decodes
1575

1576
1577
1578
1579
1580
1581
                trtllm_batch_decode_with_kv_cache(
                    query=decode_query,
                    kv_cache=kv_cache_permute,
                    workspace_buffer=workspace_buffer,
                    block_tables=block_tables_decode,
                    seq_lens=seq_lens_decode,
1582
                    max_seq_len=attn_metadata.decode.max_seq_len,
1583
1584
1585
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    window_left=self.window_left,
1586
                    sinks=self.sinks,
1587
1588
                    o_sf_scale=self.o_sf_scale,
                    out=out,
1589
1590
                    q_len_per_req=q_len_per_req,
                )
1591
        return output_padded
1592

1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Store the given keys and values into this layer's KV cache.

        Nothing is written when ``self.kv_sharing_target_layer_name`` is
        set, i.e. when this layer shares the KV cache of an earlier
        attention layer.

        Args:
            layer: Module providing the ``_k_scale`` / ``_v_scale`` values
                passed to the cache-write op.
            key: Keys to store (may be padded, see the note below).
            value: Values to store (may be padded, see the note below).
            kv_cache: Cache tensor; ``kv_cache[:, 0]`` is passed as the key
                cache and ``kv_cache[:, 1]`` as the value cache.
            slot_mapping: Destination slots for the tokens being written.
        """
        if self.kv_sharing_target_layer_name is not None:
            # KV cache is shared with an earlier attention layer: skip.
            return

        cache_write_op = torch.ops._C_cache_ops.reshape_and_cache_flash
        key_cache = kv_cache[:, 0]
        value_cache = kv_cache[:, 1]

        # Reshape the input keys and values and store them in the cache.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens]
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
        cache_write_op(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631

def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
1632
    logits_soft_cap: float | None = None,
1633
1634
    q_data_type: str | torch.dtype | None = "float16",
    kv_data_type: str | torch.dtype | None = None,
1635
    o_data_type: str | torch.dtype | None = None,
1636
    data_type: str | torch.dtype | None = None,
1637
1638
1639
    sm_scale: float | None = None,
    rope_scale: float | None = None,
    rope_theta: float | None = None,
1640
    non_blocking: bool = True,
1641
1642
    fixed_split_size: int = -1,
    disable_split_kv: bool = False,
1643
1644
) -> None:
    """
1645
1646
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
    cudagraph capture/replay, while the no cudagraph version turns back
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
    to the original plan.
    using original plan after passing host-side buffers:
    - only host-to-device copy of indptr and last_page_len buffers
    Modifications for cudagraph:
    - only host-to-device copy of indptr and last_page_len buffers.
    - avoid device-to-device copy of indices buffer.

    Part of the code get inspiration from the original plan from FlashInfer repo
    and the implementation of fast_decode_plan for FlashInfer in SGlang repo.
    """
    # Warm up with the original plan if it is first call, and always run the
    # original plan if we run for dynamic shape. For fixed shape (cudagraph),
    # this warm up is to generate the _cached_module for the decode wrapper.
1660
    if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
1661
        self.plan(
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
            indptr=indptr_cpu,
            indices=indices,
            last_page_len=last_page_len_cpu,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            page_size=page_size,
            pos_encoding_mode=pos_encoding_mode,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
            kv_data_type=kv_data_type,
            o_data_type=o_data_type,
            data_type=data_type,
            sm_scale=sm_scale,
            rope_scale=rope_scale,
            rope_theta=rope_theta,
            non_blocking=non_blocking,
            block_tables=None,
            seq_lens=None,
            fixed_split_size=fixed_split_size,
            disable_split_kv=disable_split_kv,
1684
1685
1686
1687
1688
1689
        )
        self.vllm_first_call = False
        return

    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
    fast_decode_plan(
        self,
        indptr=indptr_cpu,
        indices=indices,
        last_page_len=last_page_len_cpu,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=page_size,
        pos_encoding_mode=pos_encoding_mode,
        window_left=window_left,
        logits_soft_cap=logits_soft_cap,
        q_data_type=q_data_type,
        kv_data_type=kv_data_type,
        data_type=data_type,
        sm_scale=sm_scale,
        rope_scale=rope_scale,
        rope_theta=rope_theta,
        non_blocking=non_blocking,
        fixed_split_size=fixed_split_size,
        disable_split_kv=disable_split_kv,
1711
    )
1712

1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730

@triton.jit
def _copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy each request's block ids from ``block_table`` into ``page_indices``.

    Program ``req_idx`` (axis 0) reads ``start_idx = cu_num_blocks[req_idx]``
    and ``end_idx = cu_num_blocks[req_idx + 1]``, then copies the first
    ``end_idx - start_idx`` entries of the row starting at
    ``block_table + req_idx * block_table_stride`` to
    ``page_indices[start_idx:end_idx]``, ``BLOCK_SIZE`` elements at a time.
    """
    req_idx = tl.program_id(0)
    row_ptr = block_table + req_idx * block_table_stride
    start_idx = tl.load(cu_num_blocks + req_idx)
    end_idx = tl.load(cu_num_blocks + req_idx + 1)
    num_blocks = end_idx - start_idx

    offset = tl.arange(0, BLOCK_SIZE)
    for i in tl.range(0, num_blocks, BLOCK_SIZE):
        # Lanes past the end of this request's row are masked out for both
        # the load and the store; compute the mask once per iteration.
        in_bounds = i + offset < num_blocks
        block_ids = tl.load(row_ptr + i + offset, mask=in_bounds)
        tl.store(
            page_indices + start_idx + i + offset,
            block_ids,
            mask=in_bounds,
        )