flashinfer.py 69.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer with FlashInfer."""
4

5
from dataclasses import dataclass
6
from functools import partial
7
from typing import ClassVar
8

9
import numpy as np
10
import torch
11
12
13
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
14
    BatchPrefillWithRaggedKVCacheWrapper,
15
16
    MultiLevelCascadeAttentionWrapper,
)
17
from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache
18
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
19
from flashinfer.utils import FP4Tensor
20
from typing_extensions import override
21

22
from vllm import envs
23
24
25
26
27
from vllm.config import (
    CUDAGraphMode,
    VllmConfig,
    get_current_vllm_config_or_none,
)
28
from vllm.config.cache import CacheDType
29
from vllm.distributed.parallel_state import get_dcp_group
30
from vllm.logger import init_logger
31
from vllm.model_executor.layers.batch_invariant import (
32
    vllm_is_batch_invariant,
33
)
34
from vllm.model_executor.layers.quantization.utils.quant_utils import (
35
36
    QuantKey,
    kFp8StaticTensorSym,
37
    kNvfp4Dynamic,
38
)
39
from vllm.platforms import current_platform
40
from vllm.platforms.interface import DeviceCapability
41
from vllm.triton_utils import tl, triton
42
43
44
45
from vllm.utils.flashinfer import (
    can_use_trtllm_attention,
    use_trtllm_attention,
)
46
from vllm.utils.math_utils import cdiv
47
from vllm.utils.platform_utils import is_pin_memory_available
48
from vllm.utils.torch_utils import is_strictly_contiguous
49
50
from vllm.v1.attention.backend import (
    AttentionBackend,
51
    AttentionCGSupport,
52
    AttentionImpl,
53
    AttentionMetadataBuilder,
54
    AttentionType,
55
    CommonAttentionMetadata,
56
57
    MultipleOf,
)
58
from vllm.v1.attention.backends.utils import (
59
    KVCacheLayoutType,
60
    get_dcp_local_seq_lens,
61
62
63
64
65
    get_kv_cache_layout,
    get_per_layer_parameters,
    infer_global_hyperparameters,
    split_decodes_and_prefills,
)
66
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
67
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
68
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
69
from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs
70
from vllm.v1.utils import CpuGpuBuffer
71

72
# Workspace size in bytes (2 GiB) used in place of
# envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE when batch-invariant mode is on.
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2 * 1024 * 1024 * 1024

# FP8 element type reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 data is carried in uint8 storage.
FP4_DTYPE = torch.uint8

logger = init_logger(__name__)

79
80
81
82
83
84
85
# Module-wide workspace for the TRTLLM-gen kernels; allocated on first request.
trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
    """Return the shared TRTLLM-gen workspace, creating it on first use.

    The buffer is a zero-initialised uint8 tensor on the "cuda" device whose
    length is ``envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE``. Once created it is
    kept in the module-level ``trtllm_gen_workspace_buffer`` and reused.
    """
    global trtllm_gen_workspace_buffer
    if trtllm_gen_workspace_buffer is not None:
        return trtllm_gen_workspace_buffer

    trtllm_gen_workspace_buffer = torch.zeros(
        envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
        dtype=torch.uint8,
        device="cuda",
    )
    return trtllm_gen_workspace_buffer

90

91
92
93
94
95
96
97
98
99
100
101
102
103
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
    kv_cache_ptr,
    block_tables_prefill_ptr,
    block_table_stride,
    mock_kv_cache_ptr,
    k_scale_ptr,
    v_scale_ptr,
    K_CACHE_STRIDE: tl.constexpr,
    KV_CACHE_STRIDE: tl.constexpr,
):
    """Dequantize one KV-cache page into the dense "mock" cache.

    Program axis 0 selects a row of the block table and axis 1 selects an
    entry within that row. The page id stored at that entry is the source
    page; the destination page is the flattened table position plus one, so
    destination page 0 is never written by this kernel.

    ``K_CACHE_STRIDE`` is the number of elements in the K (or V) half of a
    page and ``KV_CACHE_STRIDE`` the number of elements in a whole page.
    """
    row = tl.program_id(0).to(tl.int64)
    col = tl.program_id(1).to(tl.int64)

    # Flattened position of this entry inside the block table.
    table_pos = row * block_table_stride + col
    src_page = tl.load(block_tables_prefill_ptr + table_pos).to(tl.int64)

    # Entries that are zero or negative are skipped entirely.
    if src_page <= 0:
        return

    out_dtype = mock_kv_cache_ptr.dtype.element_ty
    lanes = tl.arange(0, K_CACHE_STRIDE)
    src_base = src_page * KV_CACHE_STRIDE
    dst_base = (table_pos + 1) * KV_CACHE_STRIDE

    # K half: scale in float32, then cast to the destination element type.
    k_scale = tl.load(k_scale_ptr)
    k_src_offsets = src_base + lanes
    k_raw = tl.load(kv_cache_ptr + k_src_offsets)
    k_out = (k_raw.to(tl.float32) * k_scale).to(out_dtype)
    k_dst_offsets = dst_base + lanes
    tl.store(mock_kv_cache_ptr + k_dst_offsets, k_out)

    # V half: stored directly after the K half of the same page.
    v_scale = tl.load(v_scale_ptr)
    v_src_offsets = src_base + K_CACHE_STRIDE + lanes
    v_raw = tl.load(kv_cache_ptr + v_src_offsets)
    v_out = (v_raw.to(tl.float32) * v_scale).to(out_dtype)
    v_dst_offsets = dst_base + K_CACHE_STRIDE + lanes
    tl.store(mock_kv_cache_ptr + v_dst_offsets, v_out)


def trtllm_prefill_attn_kvfp8_dequant(
    kv_cache: torch.Tensor,
    block_tables_prefill: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    dequant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dequantize the KV pages referenced by a prefill batch into a dense cache.

    Args:
        kv_cache: Quantized paged KV cache with five dimensions; dimension 1
            must have size 2 (K and V).
        block_tables_prefill: 2-D table of page ids, one row per request.
        k_scale: Tensor holding the scale applied to K values.
        v_scale: Tensor holding the scale applied to V values.
        dequant_dtype: Output element type; ``torch.bfloat16`` or
            ``torch.float16``.

    Returns:
        ``(mock_kv_cache, mock_block_table)``. The mock cache has one page per
        block-table entry plus an unused page 0, and the mock block table
        numbers those pages sequentially starting at 1. Pages whose source
        entry is skipped by the kernel are left uninitialised.
    """
    # The kernel is given the column count as the row stride of the table, so
    # the table must be dense row-major. ``contiguous()`` returns the same
    # tensor when that already holds and copies a strided view otherwise.
    block_tables_prefill = block_tables_prefill.contiguous()

    batch_size, num_of_page_per_token = block_tables_prefill.shape
    s = kv_cache.shape
    assert s[1] == 2
    assert dequant_dtype in (torch.bfloat16, torch.float16)
    k_cache_stride = s[2] * s[3] * s[4]
    kv_cache_stride = k_cache_stride * s[1]
    num_mock_pages = batch_size * num_of_page_per_token + 1
    new_s = (num_mock_pages, s[1], s[2], s[3], s[4])
    # mock kv cache contains just the pages needed by this prefill
    mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device)
    # we simply sequentially index the pages needed by this prefill
    mock_block_table = torch.arange(
        start=1,
        end=num_mock_pages,
        dtype=torch.int32,
        device=block_tables_prefill.device,
    ).reshape(batch_size, num_of_page_per_token)
    grid = (batch_size, num_of_page_per_token)
    _trtllm_prefill_attn_kvfp8_dequant[grid](
        kv_cache,
        block_tables_prefill,
        num_of_page_per_token,
        mock_kv_cache,
        k_scale,
        v_scale,
        k_cache_stride,
        kv_cache_stride,
    )
    return mock_kv_cache, mock_block_table

174

175
176
177
178
class BatchDCPPrefillWrapper:
    """Prefill attention split into a cached-context pass and a new-token pass.

    The context pass runs a paged-KV prefill wrapper over queries gathered from
    the whole DCP group and then combines the partial results across the group.
    The new-token pass runs a ragged-KV prefill wrapper over the local queries
    and the freshly computed key/value tensors. ``run`` merges both passes.
    """

    def __init__(
        self,
        workspace_buffer: torch.Tensor | None = None,
        dcp_a2a: bool = False,
    ):
        """Create both FlashInfer wrappers and pick the cross-rank combiner.

        Args:
            workspace_buffer: Workspace handed to both FlashInfer wrappers.
            dcp_a2a: Use ``dcp_a2a_lse_reduce`` instead of ``cp_lse_ag_out_rs``
                to combine the context-pass results.
        """
        combine_fn = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs
        self._dcp_combine = partial(combine_fn, is_lse_base_on_e=False)

        self._context = BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer, get_kv_cache_layout()
        )
        self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(
            workspace_buffer, get_kv_cache_layout()
        )

    def plan(
        self,
        qo_indptr_cpu: torch.Tensor,
        paged_kv_indptr_cpu: torch.Tensor,
        paged_kv_indices: torch.Tensor,
        paged_kv_last_page_len_cpu: torch.Tensor,
        page_size: int,
        num_qo_heads: int,
        dcp_world_size: int,
        num_kv_heads: int,
        head_dim: int,
        sm_scale: float,
        window_left: int,
        logits_soft_cap: float | None,
        q_data_type: torch.dtype,
        kv_cache_dtype: torch.dtype,
        prefill_fixed_split_size: int,
        disable_split_kv: bool,
    ):
        """Plan both passes for the upcoming ``run`` calls.

        The context pass is planned non-causal with ``num_qo_heads`` multiplied
        by ``dcp_world_size``; the new-token pass is planned causal with
        ``qo_indptr_cpu`` used for both the query and the KV index pointers.
        """
        # Settings that are passed unchanged to both wrappers.
        common = dict(
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim,
            sm_scale=sm_scale,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
        )

        self._context.plan(
            qo_indptr=qo_indptr_cpu,
            paged_kv_indptr=paged_kv_indptr_cpu,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len_cpu,
            num_qo_heads=num_qo_heads * dcp_world_size,
            page_size=page_size,
            causal=False,  # context pass
            kv_data_type=kv_cache_dtype,
            fixed_split_size=prefill_fixed_split_size,
            disable_split_kv=disable_split_kv,
            **common,
        )
        self._new_tokens.plan(
            qo_indptr=qo_indptr_cpu,
            kv_indptr=qo_indptr_cpu,
            num_qo_heads=num_qo_heads,
            head_dim_vo=head_dim,
            causal=True,  # new-token pass
            **common,
        )

    def run(
        self,
        layer: torch.nn.Module,
        prefill_query: torch.Tensor,
        kv_cache_permute: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
    ):
        """Execute both passes and merge their results into ``out``.

        Returns:
            ``out``, after ``merge_attn_states`` has written the merged result.
        """
        dcp_group = get_dcp_group()

        # Context pass: queries from every rank in the group, gathered on dim 1.
        gathered_query = dcp_group.all_gather(prefill_query.contiguous(), dim=1)
        ctx_out_partial, ctx_lse_partial = self._context.run(
            gathered_query,
            kv_cache_permute,
            k_scale=layer._k_scale_float,
            v_scale=layer._v_scale_float,
            return_lse=True,
        )
        ctx_out, ctx_lse = self._dcp_combine(
            ctx_out_partial,
            ctx_lse_partial,
            dcp_group,
            return_lse=True,
        )
        ctx_lse = ctx_lse.transpose(0, 1).contiguous()

        # New-token pass: local queries against the new key/value tensors.
        new_out, new_lse = self._new_tokens.run(
            prefill_query,
            key,
            value,
            return_lse=True,
        )
        new_lse = new_lse.transpose(0, 1).contiguous()

        merge_attn_states(out, ctx_out, ctx_lse, new_out, new_lse)
        return out


289
class FlashInferBackend(AttentionBackend):
    """Static description of the FlashInfer attention backend."""

    accept_output_buffer: bool = True
    forward_includes_kv_cache_update: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "float16",
        "bfloat16",
        "fp8",
        "fp8_e4m3",
        "fp8_e5m2",
    ]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the KV-cache page sizes this backend accepts."""
        # Note: Not sure for all platforms, but on Blackwell,
        # only support a page size of 16, 32, 64.
        return [16, 32, 64]

    @staticmethod
    def get_name() -> str:
        """Return the backend's registry name."""
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> type["FlashInferImpl"]:
        """Return the attention implementation class."""
        return FlashInferImpl

    @staticmethod
    def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
        """Return the attention-metadata builder class."""
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape; dimension 1 holds K and V."""
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the permutation from the logical shape to the memory layout.

        The permutation is applied to the shape from ``get_kv_cache_shape``
        (optionally with a layers dimension) and depends on the configured
        layout, "NHD" or "HND".

        Raises:
            ValueError: If the configured layout is neither "NHD" nor "HND".
        """
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            if include_num_layers_dimension:
                # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
                return (1, 0, 2, 3, 4, 5)
            return (0, 1, 2, 3, 4)
        if cache_layout == "HND":
            if include_num_layers_dimension:
                # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
                return (1, 2, 4, 0, 3, 5)
            return (0, 1, 3, 2, 4)
        raise ValueError(f"Unknown cache layout format {cache_layout}.")

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 cache-dtype string to the matching torch dtype.

        Raises:
            ValueError: If the string is not one of the recognised FP8 names.
        """
        if kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes this backend accepts."""
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        return [64, 128, 256]

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return whether ``capability`` lies in the range 7.5 to 12.1 inclusive."""
        lowest = DeviceCapability(7, 5)
        highest = DeviceCapability(12, 1)
        return capability >= lowest and capability <= highest

    @classmethod
    def supports_sink(cls) -> bool:
        """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
        from vllm.utils.flashinfer import (
            force_use_trtllm_attention,
            supports_trtllm_attention,
        )

        # An explicit disable (e.g. --attention-config.use_trtllm_attention=0)
        # wins; otherwise the answer is whether the platform supports TRTLLM.
        return force_use_trtllm_attention() is not False and supports_trtllm_attention()

    @classmethod
    def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
        """Return "HND" on devices with capability major 10, else ``None``."""
        capability = current_platform.get_device_capability()
        needs_hnd = capability is not None and capability.major == 10
        return "HND" if needs_hnd else None

395
396

@dataclass
class FIPrefill:
    """Prefill metadata for the FlashInfer-native pathway (i.e. not TRTLLM)."""

    # Wrapper object used to run the prefill portion of the batch.
    wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper
401
402


403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
@dataclass
class FIDecode:
    """Metadata for the native FlashInfer decode pathway (non-TRTLLM)."""

    # Wrapper object used to run the decode portion of the batch.
    wrapper: BatchDecodeWithPagedKVCacheWrapper


@dataclass
class TRTLLMPrefill:
    """Prefill-side metadata consumed by the TRTLLM pathway."""

    # Rows of the block table that belong *only* to prefill requests.
    # Shape: [num_prefills, max_num_blocks_per_seq]
    block_tables: torch.Tensor

    # Sequence lengths of the prefill requests *only*.
    # Shape: [num_prefills]
    seq_lens: torch.Tensor

    # Presumably cumulative query / KV sequence lengths; confirm against the
    # code that constructs this object.
    cum_seq_lens_q: torch.Tensor
    cum_seq_lens_kv: torch.Tensor

    # The maximum query length *among prefill requests*.
    max_q_len: int

    # The maximum sequence length for KV Cache.
    max_seq_len: int


@dataclass
class TRTLLMDecode:
    """Decode-side metadata consumed by the TRTLLM pathway."""

    # Rows of the block table that belong *only* to decode requests.
    # Shape: [num_decodes, max_num_blocks_per_seq]
    block_tables: torch.Tensor

    # Sequence lengths of the decode requests *only*.
    # Shape: [num_decodes]
    seq_lens: torch.Tensor

    # The maximum sequence length for KV Cache.
    max_seq_len: int


@dataclass
class FlashInferMetadata:
    """Per-batch attention metadata produced for the FlashInfer backend."""

    # Total number of tokens in the batch (excluding padding).
    num_actual_tokens: int

    # Tensor for writing K/V to the cache. Shape: [num_actual_tokens]
    slot_mapping: torch.Tensor

    q_data_type: torch.dtype

    # Request and token counts for the decode and prefill parts of the batch.
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # Metadata for the prefill portion of the batch.
    # `None` when `num_prefill_tokens == 0`.
    prefill: FIPrefill | TRTLLMPrefill | None

    # Metadata for the decode portion of the batch.
    # `None` when `num_decode_tokens == 0`.
    decode: FIDecode | TRTLLMDecode | None

    # --- Special Case: Cascade Attention ---

    # If True, the entire batch is a cascade attention call, and the
    # `prefill` and `decode` fields will both be None.
    use_cascade: bool

    cascade_wrapper: MultiLevelCascadeAttentionWrapper | None
494

495

496
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
497
    reorder_batch_threshold: int = 1
498

499
500
501
502
503
504
505
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Read configuration, decide dispatch options and allocate buffers.

        Raises:
            NotImplementedError: If the layers use attention sinks but TRTLLM
                attention cannot be used.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.cache_config = vllm_config.cache_config
        self.model_config = vllm_config.model_config
        self.attention_config = vllm_config.attention_config

        # Workspace and FlashInfer wrappers are created lazily by the getters.
        self._workspace_buffer = None
        self._prefill_wrapper: (
            BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
        ) = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)

        # Batch-invariant mode uses fixed split sizes and disables split-KV.
        if vllm_is_batch_invariant():
            split_settings = (2048, 4096, True)
        else:
            split_settings = (-1, -1, False)
        (
            self.decode_fixed_split_size,
            self.prefill_fixed_split_size,
            self.disable_split_kv,
        ) = split_settings

        self.compilation_config = vllm_config.compilation_config

        # Upper bounds used to size the persistent page-table buffers.
        pages_per_req_limit = cdiv(
            self.model_config.max_model_len, self.kv_cache_spec.block_size
        )
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        max_num_pages = max_num_reqs * pages_per_req_limit

        speculative_config = vllm_config.speculative_config
        if speculative_config is not None:
            num_spec_tokens = speculative_config.num_speculative_tokens
        else:
            num_spec_tokens = 0

        decode_cg_mode = self.compilation_config.cudagraph_mode.decode_mode()
        self.enable_cuda_graph = decode_cg_mode == CUDAGraphMode.FULL
        if self.enable_cuda_graph:
            # For full cudagraph capture, one `decode_wrapper` for each batch
            # size is needed for FlashInfer.
            self._decode_wrappers_cudagraph: dict[
                int, BatchDecodeWithPagedKVCacheWrapper
            ] = {}
            cudagraph_max_bs = (1 + num_spec_tokens) * max_num_reqs
            capture_limit = self.compilation_config.max_cudagraph_capture_size
            if capture_limit is not None:
                cudagraph_max_bs = min(cudagraph_max_bs, capture_limit)
            self._decode_cudagraph_max_bs = cudagraph_max_bs

        try:
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
            self.dcp_kv_cache_interleave_size = (
                vllm_config.parallel_config.dcp_kv_cache_interleave_size
            )
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
            self.dcp_kv_cache_interleave_size = 1
        self.use_dcp = self.dcp_world_size > 1
        self.dcp_a2a = (
            self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a"
        )

        self.num_qo_heads = self.model_config.get_num_attention_heads(
            self.vllm_config.parallel_config
        )
        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
        self.head_dim = self.kv_cache_spec.head_size
        self.page_size = self.kv_cache_spec.block_size

        self.cache_dtype = self.cache_config.cache_dtype
        if self.cache_dtype.startswith("fp8"):
            self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.cache_dtype
            )
        else:
            assert self.kv_cache_spec.dtype == self.model_config.dtype
            self.kv_cache_dtype = self.kv_cache_spec.dtype

        # The query uses the KV-cache dtype only when TRTLLM attention can be
        # used and --attention-config.disable_flashinfer_q_quantization is not
        # set; otherwise it uses the model dtype. It may still fall back to the
        # model dtype later if the TRTLLM kernel ends up not being used when
        # the attention metadata is built.
        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
        quantize_q = (
            can_use_trtllm
            and not vllm_config.attention_config.disable_flashinfer_q_quantization
        )
        self.q_data_type = (
            self.kv_cache_dtype if quantize_q else self.model_config.dtype
        )

        # Prefer TRTLLM attention for decoding in all cases.
        # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
        self.use_trtllm_decode_attention = can_use_trtllm
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)

        self._cascade_wrapper = None  # Wrapper for cascade attention

        # Global hyperparameters shared by all attention layers
        # TODO: discard this for trtllm-gen backend
        hyperparams = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)
        )
        self.global_hyperparameters = hyperparams
        self.sm_scale = hyperparams.sm_scale
        self.window_left = hyperparams.window_left
        self.logits_soft_cap = hyperparams.logits_soft_cap
        self.has_sinks = hyperparams.has_sinks
        if self.has_sinks and not can_use_trtllm:
            raise NotImplementedError(
                "FlashInfer backend currently does not support attention "
                "sinks, please use trtllm on blackwell or flash attention on "
                "earlier GPUs."
            )

        # Preparing persistent buffers
        # Since we do not have explicit synchronization in ModelRunnerV2, we do not pin
        # reused CPU buffers to avoid a race condition between step N async copies to
        # GPU and step N+1 buffer updates.
        self.pin_memory = (
            not envs.VLLM_USE_V2_MODEL_RUNNER and is_pin_memory_available()
        )
        self.paged_kv_indptr = self._make_buffer(max_num_reqs + 1)
        # Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode
        self.paged_kv_indptr_cpu_buffer = torch.zeros_like(
            self.paged_kv_indptr.cpu, pin_memory=self.pin_memory
        )
        self.paged_kv_indices = self._make_buffer(max_num_pages)
        self.paged_kv_last_page_len = self._make_buffer(max_num_reqs)
634

635
636
637
638
639
640
641
642
643
644
645
646
    def _make_buffer(
        self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32
    ) -> CpuGpuBuffer:
        """Create a ``CpuGpuBuffer`` of the given size on this builder's device.

        The buffer uses the builder's ``pin_memory`` setting and is created
        with ``with_numpy=True``.
        """
        options = {
            "dtype": dtype,
            "device": self.device,
            "pin_memory": self.pin_memory,
            "with_numpy": True,
        }
        return CpuGpuBuffer(*size, **options)

    @override  # type: ignore[misc]
    @classmethod
    def get_cudagraph_support(
        cls: type["FlashInferMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        """Report the CUDA-graph support level for FlashInfer attention.

        ``UNIFORM_BATCH`` is returned only when there is at least one KV-cache
        spec and ``can_use_trtllm_attention`` accepts every attention spec
        among them; otherwise ``UNIFORM_SINGLE_TOKEN_DECODE`` is returned.
        For ``UniformTypeKVCacheSpecs`` every contained spec is checked.
        """
        if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
            kv_specs = kv_cache_spec.kv_cache_specs.values()
        else:
            kv_specs = [kv_cache_spec]

        num_qo_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )

        # FlashInfer only applies to attention, so other kinds of KV spec
        # (e.g. Mamba) are ignored rather than treated as unsupported.
        trtllm_for_all = len(kv_specs) > 0 and all(
            can_use_trtllm_attention(
                num_qo_heads=num_qo_heads,
                num_kv_heads=spec.num_kv_heads,
            )
            for spec in kv_specs
            if isinstance(spec, AttentionSpec)
        )

        if trtllm_for_all:
            return AttentionCGSupport.UNIFORM_BATCH
        return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

688
689
    def _get_workspace_buffer(self):
        if self._workspace_buffer is None:
690
            buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
691
            if vllm_is_batch_invariant():
692
                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
693
            self._workspace_buffer = torch.zeros(
694
                buffer_size, dtype=torch.uint8, device=self.device
695
            )
696
697
        return self._workspace_buffer

Woosuk Kwon's avatar
Woosuk Kwon committed
698
699
700
    def set_workspace_buffer(self, workspace_buffer: torch.Tensor) -> None:
        """Use ``workspace_buffer`` as this builder's workspace.

        Once set, ``_get_workspace_buffer`` returns this tensor instead of
        allocating a new one.
        """
        self._workspace_buffer = workspace_buffer

701
702
703
    def _get_prefill_wrapper(
        self,
    ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
704
        if self._prefill_wrapper is None:
705
            if self.use_dcp:
706
707
                self._prefill_wrapper = BatchDCPPrefillWrapper(
                    workspace_buffer=self._get_workspace_buffer(),
708
                    dcp_a2a=self.dcp_a2a,
709
710
711
712
713
714
                )
            else:
                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                    self._get_workspace_buffer(), get_kv_cache_layout()
                )
        assert self._prefill_wrapper is not None
715
716
        return self._prefill_wrapper

717
    def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
718
        if use_cudagraph:
719
            decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None)
720
721
722
723
724
        else:
            decode_wrapper = self._decode_wrapper

        if decode_wrapper is None:
            if use_cudagraph:
725
726
727
                paged_kv_indptr = self.paged_kv_indptr.gpu[: batch_size + 1]
                paged_kv_indices = self.paged_kv_indices.gpu
                paged_kv_last_page_len = self.paged_kv_last_page_len.gpu[:batch_size]
728
729
730
731
732
            else:
                paged_kv_indptr = None
                paged_kv_indices = None
                paged_kv_last_page_len = None
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
733
                self._get_workspace_buffer(),
734
                get_kv_cache_layout(),
735
736
737
738
                use_cuda_graph=use_cudagraph,
                paged_kv_indptr_buffer=paged_kv_indptr,
                paged_kv_indices_buffer=paged_kv_indices,
                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
739
                # Tensor cores are enabled by default because the perf would be
co63oc's avatar
co63oc committed
740
                # at least as good as cuda cores for all attention ops in latest
741
742
743
                # gpus.
                use_tensor_cores=True,
            )
744
745
746
747
748
749
750
751

            # save the decode wrapper
            if use_cudagraph:
                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
            else:
                self._decode_wrapper = decode_wrapper

        return decode_wrapper
752
753
754
755

    def _get_cascade_wrapper(self):
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
756
757
                2, self._get_workspace_buffer(), get_kv_cache_layout()
            )
758
759
        return self._cascade_wrapper

760
    def _compute_flashinfer_kv_metadata(
        self,
        num_blocks_np: np.ndarray,
        seq_lens_np: np.ndarray,
        block_table_tensor: torch.Tensor,
        num_reqs: int,
        page_size: int,
    ) -> torch.Tensor:
        """Fill the paged-KV index buffers that FlashInfer attention consumes.

        Writes ``self.paged_kv_indptr``, ``self.paged_kv_indices`` and
        ``self.paged_kv_last_page_len`` for the first ``num_reqs`` requests.

        Returns:
            The GPU slice of ``self.paged_kv_indices`` covering the pages in
            use, with shape [num_actual_pages].
        """
        indptr_len = num_reqs + 1

        # Running page counts go into entries 1..num_reqs of the indptr
        # buffer; entry 0 is always 0.
        np.cumsum(
            num_blocks_np,
            dtype=np.int32,
            out=self.paged_kv_indptr.np[1:indptr_len],
        )
        # NOTE(woosuk): self.paged_kv_indptr.cpu can be modified after this
        # point (e.g., for cuda graphs), so the values are first copied into
        # self.paged_kv_indptr_cpu_buffer and the asynchronous host-to-device
        # copy reads from that copy, which avoids a race condition.
        self.paged_kv_indptr_cpu_buffer[:indptr_len] = self.paged_kv_indptr.cpu[
            :indptr_len
        ]
        indptr_gpu = self.paged_kv_indptr.gpu[:indptr_len]
        indptr_gpu.copy_(
            self.paged_kv_indptr_cpu_buffer[:indptr_len], non_blocking=True
        )

        # Gather each request's page ids from the block table on the GPU.
        total_pages = self.paged_kv_indptr.np[num_reqs]
        page_indices_gpu = self.paged_kv_indices.gpu[:total_pages]
        _copy_page_indices_kernel[(num_reqs,)](
            page_indices_gpu,
            block_table_tensor,
            block_table_tensor.stride(0),
            indptr_gpu,
            BLOCK_SIZE=1024,
        )

        # Tokens in each request's last page. A non-empty sequence whose
        # length is an exact multiple of the page size has a full last page,
        # not an empty one.
        remainder_np = seq_lens_np % page_size
        last_page_is_full = (remainder_np == 0) & (seq_lens_np != 0)
        self.paged_kv_last_page_len.np[:num_reqs] = np.where(
            last_page_is_full,
            page_size,
            remainder_np,
        )
        self.paged_kv_last_page_len.gpu[:num_reqs].copy_(
            self.paged_kv_last_page_len.cpu[:num_reqs], non_blocking=True
        )
        return page_indices_gpu
816

817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashInferMetadata:
        """Build the FlashInfer attention metadata for one batch.

        Decode requests sit at the front of the batch and prefill requests at
        the back. Depending on the batch contents, one of the following is
        planned and attached to the returned metadata:

        - cascade attention, when ``common_prefix_len > 0`` (early return);
        - a prefill pathway (TRTLLM or FlashInfer-native);
        - a decode pathway (TRTLLM or FlashInfer-native).

        As a side effect, ``self.q_data_type`` is reset to the model dtype
        whenever some part of the batch cannot use TRTLLM attention.

        Args:
            common_prefix_len: Number of tokens in the prefix shared by the
                requests. A positive value selects cascade attention and must
                be a multiple of the page size.
            common_attn_metadata: Backend-agnostic batch metadata (sequence
                lengths, block table, query offsets, slot mapping, ...).
            fast_build: Not used by this implementation.

        Returns:
            A ``FlashInferMetadata`` whose ``prefill`` / ``decode`` /
            ``cascade_wrapper`` fields are populated as applicable.

        Raises:
            NotImplementedError: If attention sinks are present but part of
                the batch cannot use TRTLLM attention.
            ValueError: If a non-TRTLLM pathway is needed and the layers do
                not all share the same ``window_left``.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True,
            )
        )

        page_size = self.page_size
        max_seq_len = common_attn_metadata.max_seq_len
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        qo_indptr = common_attn_metadata.query_start_loc
        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu

        # Step 1: Decide which dispatch modes to use:
        # - Cascade attention (distinct mode)
        # - Prefill (FI native or TRTLLM)
        # - Decode (FI native or TRTLLM)
        use_cascade = common_prefix_len > 0
        uses_spec_reorder = self.reorder_batch_threshold > 1
        prefill_use_trtllm = use_trtllm_attention(
            self.num_qo_heads,
            self.num_kv_heads,
            num_prefill_tokens,
            max_seq_len,
            self.dcp_world_size,
            self.cache_dtype,
            self.q_data_type,
            is_prefill=True,
            force_use_trtllm=self.attention_config.use_trtllm_attention,
            has_sinks=self.has_sinks,
            has_spec=uses_spec_reorder,
        )
        decode_use_trtllm = (
            self.use_trtllm_decode_attention and self.dcp_world_size <= 1
        )

        all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and (
            num_decodes == 0 or decode_use_trtllm
        )
        is_only_trtllm_decode = num_prefills == 0 and (
            num_decodes > 0 and decode_use_trtllm
        )

        if not all_uses_trtllm:
            if self.has_sinks:
                raise NotImplementedError(
                    "FlashInfer backend currently does not support attention "
                    "sinks, please use trtllm on blackwell or flash attention "
                    "on earlier GPUs."
                )

            if not self.global_hyperparameters.has_same_window_lefts:
                raise ValueError(
                    "Window left is not the same for all layers. "
                    "One potential fix is to set disable_sliding_window=True"
                )

            assert self.global_hyperparameters.has_same_all_params, (
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values for the following "
                "hyperparameters: `window_left`, `logits_soft_cap`, "
                "`sm_scale`."
            )

            # The q quantization is not supported for non-trtllm attention,
            # fall back to model dtype.
            self.q_data_type = self.model_config.dtype

        # Step 2: Initialize the output metadata
        # Leave prefill/decode/cascade_wrapper empty, to be populated
        # case by case depending on the batch contents and backend selection.
        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            slot_mapping=common_attn_metadata.slot_mapping,
            q_data_type=self.q_data_type,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            use_cascade=use_cascade,
            prefill=None,
            decode=None,
            cascade_wrapper=None,
        )

        # Guard access to seq_lens_cpu, which may not always be needed
        # and can be expensive to retrieve in async mode.
        needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
        seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
        # Ceil-divide the sequence lengths by the page size.
        num_blocks_np = (
            (seq_lens_np + (page_size - 1)) // page_size
            if seq_lens_np is not None
            else None
        )

        # Adjust seq_lens_cpu for DCP
        # NOTE(review): seq_lens_np / num_blocks_np above are derived before
        # this adjustment, and the tensor returned by get_dcp_local_seq_lens is
        # only checked for None afterwards. Confirm this ordering is intended.
        if self.use_dcp:
            assert seq_lens_cpu is not None
            if num_prefills > 0:
                qo_indptr_prefill_cpu = (
                    qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
                )
                query_lens_prefill_cpu = (
                    qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
                )
                seq_lens_cpu[num_decodes:] = (
                    seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
                )

            seq_lens_cpu = get_dcp_local_seq_lens(
                seq_lens_cpu,
                self.dcp_world_size,
                self.dcp_rank,
                self.dcp_kv_cache_interleave_size,
            )

        # Adjust num_blocks_np and the block table for cascade attention
        if use_cascade:
            assert num_blocks_np is not None
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size
            # Grab the blocks of the shared prefix from the first request.
            shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
            # Remove the blocks of the shared prefix from all requests. This
            # must happen before the per-request page indices are gathered
            # below, so that they cover only the non-shared suffix blocks and
            # stay consistent with the reduced num_blocks_np.
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
            num_blocks_np -= num_common_kv_blocks

        # Compute paged_kv_indices if necessary
        needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode
        if needs_paged_kv_indices:
            assert num_blocks_np is not None
            assert seq_lens_np is not None
            paged_kv_indices = self._compute_flashinfer_kv_metadata(
                num_blocks_np,
                seq_lens_np,
                block_table_tensor,
                num_reqs,
                page_size,
            )
        else:
            paged_kv_indices = None

        # Early-out for cascade attention
        if use_cascade:
            assert paged_kv_indices is not None

            # Create CPU versions directly for cascade (no GPU versions needed)
            shared_qo_indptr_cpu = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device="cpu"
            )
            shared_kv_page_indptr_cpu = torch.tensor(
                [0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
            )
            # The shared prefix is page-aligned (asserted above), so its last
            # page is always full.
            shared_kv_last_page_len_cpu = torch.tensor(
                [page_size], dtype=torch.int32, device="cpu"
            )

            paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[: 1 + num_reqs]
            paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]

            # Level 0 is the shared prefix (one entry attended to by all
            # tokens); level 1 is the per-request remainder.
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                qo_indptr_arr=[shared_qo_indptr_cpu, qo_indptr_cpu],
                paged_kv_indptr_arr=[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                paged_kv_indices_arr=[shared_kv_page_indices_cpu, paged_kv_indices],
                paged_kv_last_page_len=[
                    shared_kv_last_page_len_cpu,
                    paged_kv_last_page_len_cpu,
                ],
                num_qo_heads=self.num_qo_heads,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                page_size=self.page_size,
                causal=True,
                sm_scale=self.sm_scale,
                window_left=self.window_left,
                logits_soft_cap=self.logits_soft_cap,
                q_data_type=self.q_data_type,
                kv_data_type=self.kv_cache_dtype,
            )
            return attn_metadata

        # Step 3: Handle prefill and decode pathways case by case
        ## PREFILL PATHWAY
        if num_prefills > 0:
            # Slices for shared prefill metadata
            prefill_start = num_decodes
            qo_indptr_prefill_cpu = (
                qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
            )
            assert qo_indptr_prefill_cpu.shape[0] == num_prefills + 1

            if prefill_use_trtllm:
                # Create GPU versions
                qo_indptr_prefill_gpu = (
                    qo_indptr[prefill_start:] - qo_indptr[prefill_start]
                )
                paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
                    prefill_start : num_reqs + 1
                ]
                # Compute max_q_len for prefill requests
                query_lens_prefill_cpu = (
                    qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
                )
                max_q_len_prefill = int(query_lens_prefill_cpu.max().item())
                attn_metadata.prefill = TRTLLMPrefill(
                    block_tables=block_table_tensor[prefill_start:],
                    seq_lens=seq_lens[prefill_start:],
                    cum_seq_lens_q=qo_indptr_prefill_gpu,
                    cum_seq_lens_kv=paged_kv_indptr_prefill_gpu,
                    max_q_len=max_q_len_prefill,
                    max_seq_len=max_seq_len,
                )
            else:
                prefill_wrapper = self._get_prefill_wrapper()
                # Slicing CPU buffers that are only needed for FI native prefills
                paged_kv_last_page_len_prefill_cpu = self.paged_kv_last_page_len.cpu[
                    prefill_start:num_reqs
                ]
                assert paged_kv_last_page_len_prefill_cpu.shape[0] == num_prefills
                paged_kv_indptr_prefill_cpu = self.paged_kv_indptr.cpu[
                    prefill_start : num_reqs + 1
                ]
                assert paged_kv_indptr_prefill_cpu.shape[0] == num_prefills + 1
                if self.use_dcp:
                    assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
                    prefill_wrapper.plan(
                        qo_indptr_cpu=qo_indptr_prefill_cpu,
                        paged_kv_indptr_cpu=paged_kv_indptr_prefill_cpu,
                        paged_kv_indices=paged_kv_indices,
                        paged_kv_last_page_len_cpu=paged_kv_last_page_len_prefill_cpu,
                        page_size=self.page_size,
                        num_qo_heads=self.num_qo_heads,
                        dcp_world_size=self.dcp_world_size,
                        num_kv_heads=self.num_kv_heads,
                        head_dim=self.head_dim,
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_cache_dtype=self.kv_cache_dtype,
                        prefill_fixed_split_size=self.prefill_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
                    )
                else:
                    assert isinstance(
                        prefill_wrapper,
                        BatchPrefillWithPagedKVCacheWrapper,
                    )
                    prefill_wrapper.plan(
                        qo_indptr=qo_indptr_prefill_cpu,
                        paged_kv_indptr=paged_kv_indptr_prefill_cpu,
                        paged_kv_indices=paged_kv_indices,
                        paged_kv_last_page_len=paged_kv_last_page_len_prefill_cpu,
                        num_qo_heads=self.num_qo_heads,
                        num_kv_heads=self.num_kv_heads,
                        head_dim_qk=self.head_dim,
                        page_size=self.page_size,
                        causal=True,
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                        o_data_type=self.model_config.dtype,
                        fixed_split_size=self.prefill_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
                    )
                attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper)

        ## DECODE PATHWAY
        if num_decodes > 0:
            if decode_use_trtllm:
                assert num_decode_tokens % num_decodes == 0, (
                    "TRTLLM decode requires uniform query lengths per request. "
                    f"Got {num_decode_tokens=} and {num_decodes=}."
                )
                attn_metadata.decode = TRTLLMDecode(
                    block_tables=block_table_tensor[:num_decodes],
                    seq_lens=seq_lens[:num_decodes],
                    max_seq_len=max_seq_len,
                )
            else:
                assert seq_lens_cpu is not None
                pure_decode = num_prefills == 0
                use_cudagraph = (
                    self.enable_cuda_graph
                    and pure_decode
                    and num_decode_tokens <= self._decode_cudagraph_max_bs
                )
                num_input_tokens = num_decode_tokens

                decode_wrapper = self._get_decode_wrapper(
                    num_input_tokens, use_cudagraph
                )
                # Use the persistent buffer with padding length,
                # instead of the same address but chunked version
                # in attn_metadata when using cudagraph.
                fast_plan_decode(
                    decode_wrapper,
                    indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
                    indices=paged_kv_indices,
                    last_page_len_cpu=self.paged_kv_last_page_len.cpu[
                        :num_input_tokens
                    ],
                    num_qo_heads=self.num_qo_heads * self.dcp_world_size,
                    num_kv_heads=self.num_kv_heads,
                    head_dim=self.head_dim,
                    page_size=self.page_size,
                    # Disable flashinfer's pos encoding and use vllm's rope.
                    pos_encoding_mode="NONE",
                    sm_scale=self.sm_scale,
                    window_left=self.window_left,
                    logits_soft_cap=self.logits_soft_cap,
                    q_data_type=self.q_data_type,
                    kv_data_type=self.kv_cache_dtype,
                    o_data_type=self.model_config.dtype,
                    fixed_split_size=self.decode_fixed_split_size,
                    disable_split_kv=self.disable_split_kv,
                )
                attn_metadata.decode = FIDecode(wrapper=decode_wrapper)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Report whether cascade attention should be used for this batch.

        Cascade attention is currently disabled for this backend, so the
        result is always ``False``. The positional and keyword arguments are
        accepted but not inspected.

        Returns:
            ``False`` in every case.
        """
        if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        # TODO: Cascade attention doesn't work, disable it for now
        # return use_cascade_attention(*args, **kwargs)
        return False
1163
1164
1165


class FlashInferImpl(AttentionImpl):
1166
1167
    can_return_lse_for_decode: bool = True

1168
1169
1170
1171
1172
1173
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Set up the per-layer state of the FlashInfer attention implementation.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale; stored as a Python ``float``.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes; stored as a float32 tensor
                when given.
            sliding_window: Optional sliding-window size. Stored as
                ``(sliding_window - 1, 0)``, or ``(-1, -1)`` when ``None``;
                ``window_left`` is the first element of that pair.
            kv_cache_dtype: KV cache dtype string (e.g. one starting with
                ``"fp8"``).
            logits_soft_cap: Optional logits soft cap.
            attn_type: Attention type; only ``AttentionType.DECODER`` is
                accepted.
            kv_sharing_target_layer_name: Stored as-is; presumably the name of
                the layer whose KV cache this layer shares (only its None-ness
                is inspected in the code visible here).
            sinks: Optional attention sinks; the first dimension must equal
                ``num_heads``.

        Raises:
            NotImplementedError: If ``attn_type`` is not
                ``AttentionType.DECODER``.
            ValueError: If ``sinks`` does not have ``num_heads`` entries along
                its first dimension.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.alibi_slopes = (
            torch.tensor(alibi_slopes, dtype=torch.float32)
            if alibi_slopes is not None
            else None
        )
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        # self.sliding_window is always a tuple at this point, so the left
        # window is simply its first element (-1 when there is no window).
        self.window_left = self.sliding_window[0]
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferImpl"
            )

        self.sinks: torch.Tensor | None = None
        if sinks is not None:
            if sinks.shape[0] != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}."
                )
            self.sinks = sinks

        self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
        vllm_config = get_current_vllm_config_or_none()
        # Quantized query input needs TRTLLM attention support, an active vLLM
        # config, and q quantization not being disabled in that config.
        self.supports_quant_query_input = (
            self.support_trtllm_attn
            and vllm_config is not None
            and not vllm_config.attention_config.disable_flashinfer_q_quantization
        )
        # Scales are filled in lazily by forward().
        self.bmm1_scale: float | None = None
        self.bmm2_scale: float | None = None
        self.o_sf_scale: float | None = None

        # Pick the DCP combine function: the a2a variant when decode context
        # parallelism is enabled with the "a2a" comm backend, otherwise
        # cp_lse_ag_out_rs.
        dcp_a2a = (
            vllm_config is not None
            and vllm_config.parallel_config.decode_context_parallel_size > 1
            and vllm_config.parallel_config.dcp_comm_backend == "a2a"
        )
        if dcp_a2a:
            self.dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
        else:
            self.dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)

1241
    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """Return whether output quantization can be fused for ``quant_key``.

        Fusion is reported as supported only when all of the following hold:
        TRTLLM attention is supported for this layer, the KV cache dtype
        string starts with ``"fp8"``, and ``quant_key`` is either
        ``kFp8StaticTensorSym`` or ``kNvfp4Dynamic``.
        """
        return (
            self.support_trtllm_attn
            and self.kv_cache_dtype.startswith("fp8")
            and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
        )
1247

1248
1249
1250
1251
1252
    # FlashInfer requires attention sinks to be float32
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Convert the attention sinks to float32 after weights are loaded.

        Does nothing when the layer has no sinks or when they are already
        float32. ``act_dtype`` is accepted but not used here.
        """
        sinks = self.sinks
        if sinks is None or sinks.dtype == torch.float32:
            return
        self.sinks = sinks.to(torch.float32)

1253
1254
1255
1256
1257
1258
1259
1260
    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
1261
1262
1263
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
1264
1265
1266
1267
1268
1269
1270
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
1271
1272
1273
            kv_cache: KV cache tensor with different possible shapes:
                - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
                - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
1274
1275
1276
1277
1278
1279
1280
1281
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
1282
            return output.fill_(0)
1283

1284
1285
1286
1287
1288
1289
        # Ensure query dtype matches the expected dtype from attention metadata
        assert attn_metadata.q_data_type == query.dtype, (
            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
            f"got {query.dtype}"
        )

1290
        if self.bmm1_scale is None:
1291
            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
1292
1293
1294
1295

        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

1296
1297
1298
        prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
        decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)

1299
1300
        # The attn+quant fusion happens when output_scale is provided.
        if output_scale is None:
1301
1302
1303
            assert output_block_scale is None, (
                "output_block_scale is not supported when fusion has not happened"
            )
1304
        else:
1305
            assert attn_metadata.q_data_type == FP8_DTYPE, (
1306
                "Query must be FP8 when attn+quant fusion happened."
1307
            )
1308
1309
            assert (attn_metadata.num_prefills == 0 or prefill_use_trtllm) and (
                attn_metadata.num_decodes == 0 or decode_use_trtllm
1310
            ), "Must use TRT-LLM attn"
1311

1312
            if output.dtype == FP8_DTYPE:
1313
                assert output_block_scale is None, (
1314
                    "output_block_scale should not be provided for fp8 output"
1315
                )
1316
            elif output.dtype == FP4_DTYPE:
1317
                assert output_block_scale is not None, (
1318
                    "output_block_scale is required for nvfp4 output"
1319
                )
1320
1321
1322
            else:
                raise ValueError(f"Unsupported output dtype: {output.dtype}")

1323
            # TRTLLM attn kernel requires to scale to pass as a host scalar,
1324
1325
            # store the o scale as a host scalar in warmup run with cuda graph
            # not enabled
1326
1327
            if layer._o_scale_float is None:
                layer._o_scale_float = output_scale.cpu().item()
1328
1329
1330
1331
                if output.dtype == FP8_DTYPE:
                    self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
                elif output.dtype == FP4_DTYPE:
                    self.o_sf_scale = layer._o_scale_float
1332

1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
1343

1344
1345
1346
1347
1348
1349
1350
        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype
1351
            )
1352
            kv_cache = kv_cache.view(torch_dtype)
1353

1354
1355
        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
1356
1357
        key = key[:num_actual_tokens]
        value = value[:num_actual_tokens]
1358
1359
1360
1361
1362
1363
1364
1365
1366
        output_padded = output
        output = output[:num_actual_tokens]

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            return output

1367
1368
        # When using spec decoding, num_decodes can be < num_decode_tokens
        # because some decode requests may have more than one query token.
1369
1370
1371
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

1372
        stride_order = FlashInferBackend.get_kv_cache_stride_order()
1373
        kv_cache_permute = kv_cache.permute(*stride_order)
1374
1375
1376

        use_dcp = self.dcp_world_size > 1

1377
        # Regular attention (common case).
1378
        # Decodes are at the front and prefills are at the back.
1379
        if num_prefill_tokens > 0:
1380
1381
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
1382

1383
1384
1385
1386
1387
            if not prefill_use_trtllm:
                assert isinstance(attn_metadata.prefill, FIPrefill)
                prefill_wrapper = attn_metadata.prefill.wrapper
                assert prefill_wrapper is not None
                if use_dcp:
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
                    assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
                    assert prefill_wrapper._context._window_left == self.window_left
                    assert prefill_wrapper._context._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._context._sm_scale == self.scale
                    assert not prefill_wrapper._context._causal
                    assert prefill_wrapper._new_tokens._window_left == self.window_left
                    assert prefill_wrapper._new_tokens._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._new_tokens._sm_scale == self.scale
                    assert prefill_wrapper._new_tokens._causal

                    prefill_wrapper.run(
                        layer,
                        prefill_query,
                        kv_cache_permute,
                        key[num_decode_tokens:],
                        value[num_decode_tokens:],
                        out=output[num_decode_tokens:],
                    )
                else:
                    assert isinstance(
                        prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper
                    )
                    assert prefill_wrapper._window_left == self.window_left
                    assert prefill_wrapper._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._sm_scale == self.scale
                    assert prefill_wrapper._causal
                    prefill_wrapper.run(
                        prefill_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[num_decode_tokens:],
                    )
1427
            else:
1428
                assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
1429
1430
1431
1432
1433
                # prefill_query may be non-contiguous or have degenerate strides
                # First ensure memory contiguity, then fix degenerate strides
                # with reshape. contiguous() alone doesn't fix degenerate
                # strides when a dimension has size 1.
                prefill_query = prefill_query.contiguous().reshape(prefill_query.shape)
1434
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
1435
1436
                block_tables_prefill = attn_metadata.prefill.block_tables
                seq_lens_prefill = attn_metadata.prefill.seq_lens
1437
1438
1439

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
1440
1441
1442
1443
                assert is_strictly_contiguous(prefill_query)
                assert is_strictly_contiguous(workspace_buffer)
                assert is_strictly_contiguous(block_tables_prefill)
                assert is_strictly_contiguous(seq_lens_prefill)
1444

1445
1446
                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
1447
1448
1449
1450
1451
1452
                    out = FP4Tensor(
                        data=output[num_decode_tokens:],
                        scale=output_block_scale,
                        scale_start_index=num_decode_tokens,
                        original_shape=prefill_query.shape,
                    )
1453
1454
1455
1456
                else:
                    assert self.o_sf_scale is None
                    out = output[num_decode_tokens:]

1457
1458
1459
1460
                if (
                    attn_metadata.q_data_type != FP8_DTYPE
                    and self.kv_cache_dtype.startswith("fp8")
                ):
1461
1462
1463
1464
                    # TRTLLM prefill attention does not support BF16 Q
                    # and fp8 kv cache. So to enable prefill attention
                    # with fp8 kv cache, we can construct a mock block
                    # and mock kv cache with BF16 KV involved in the prefill
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
                    #
                    # The inner (block_size, head_size) dims must be
                    # contiguous; outer dims may have non-canonical strides
                    # (e.g. cross-layer unified allocation).
                    # Degenerate strides on outer dims break TMA descriptors
                    # (see flashinfer-ai/flashinfer#2232).
                    kv_strides = kv_cache_permute.stride()
                    assert (
                        kv_strides[-1] == 1
                        and kv_strides[-2] == kv_cache_permute.shape[-1]
                    ), (
                        "KV cache inner dims (block_size, head_size) must be "
                        f"contiguous, got strides {kv_strides}"
                    )
1479
1480
1481
1482
1483
1484
1485
                    mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
                        kv_cache_permute,
                        block_tables_prefill,
                        layer._k_scale,
                        layer._v_scale,
                        attn_metadata.q_data_type,
                    )
1486
1487
1488
1489
                else:
                    mock_kv_cache = kv_cache_permute
                    mock_block_table = block_tables_prefill

1490
1491
                trtllm_batch_context_with_kv_cache(
                    query=prefill_query,
1492
                    kv_cache=mock_kv_cache,
1493
                    workspace_buffer=workspace_buffer,
1494
                    block_tables=mock_block_table,
1495
                    seq_lens=seq_lens_prefill,
1496
1497
                    max_q_len=attn_metadata.prefill.max_q_len,
                    max_kv_len=attn_metadata.prefill.max_seq_len,
1498
1499
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
1500
                    batch_size=attn_metadata.num_prefills,
1501
1502
                    cum_seq_lens_q=attn_metadata.prefill.cum_seq_lens_q,
                    cum_seq_lens_kv=attn_metadata.prefill.cum_seq_lens_kv,
1503
                    window_left=self.window_left,
1504
                    sinks=self.sinks,
1505
1506
                    o_sf_scale=self.o_sf_scale,
                    out=out,
1507
1508
1509
                )

        if num_decode_tokens > 0:
1510
1511
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
1512

1513
1514
1515
1516
            if not decode_use_trtllm:
                assert isinstance(attn_metadata.decode, FIDecode)
                decode_wrapper = attn_metadata.decode.wrapper
                assert decode_wrapper is not None
1517
                assert decode_wrapper._window_left == self.window_left
1518
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
1519
                assert decode_wrapper._sm_scale == self.scale
1520

1521
                if use_dcp:
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
                    decode_query = get_dcp_group().all_gather(
                        decode_query.contiguous(), dim=-2
                    )
                    output_tmp = torch.empty_like(decode_query)
                    lse = torch.empty(
                        (decode_query.size(0), decode_query.size(1)),
                        dtype=torch.float32,
                        device=decode_query.device,
                    )
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output_tmp,
                        lse=lse,
                        return_lse=True,
                    )
1540
                    output[:num_decode_tokens] = self.dcp_combine(
1541
1542
1543
                        output_tmp,
                        lse,
                        get_dcp_group(),
1544
1545
1546
1547
1548
1549
1550
1551
1552
                    )
                else:
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[:num_decode_tokens],
                    )
1553
            else:
1554
                # decode_query may be non-contiguous or have degenerate strides
1555
                assert isinstance(attn_metadata.decode, TRTLLMDecode)
1556
1557
1558
1559
                # First ensure memory contiguity, then fix degenerate strides
                # with reshape. contiguous() alone doesn't fix degenerate
                # strides when a dimension has size 1.
                decode_query = decode_query.contiguous().reshape(decode_query.shape)
1560
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
1561
1562
                block_tables_decode = attn_metadata.decode.block_tables
                seq_lens_decode = attn_metadata.decode.seq_lens
1563

1564
                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
1565
                assert get_kv_cache_layout() == "HND"
1566
1567
1568
1569
                assert is_strictly_contiguous(decode_query)
                assert is_strictly_contiguous(workspace_buffer)
                assert is_strictly_contiguous(block_tables_decode)
                assert is_strictly_contiguous(seq_lens_decode)
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
                # kv_cache outer dims may be non-contiguous (e.g.
                # cross-layer unified allocation), but inner dims
                # (block_size, head_size) must be contiguous and
                # strides must be canonical to avoid TMA descriptor
                # failures (see flashinfer-ai/flashinfer#2232).
                kv_strides = kv_cache_permute.stride()
                assert (
                    kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1]
                ), (
                    "KV cache inner dims (block_size, head_size) must be "
                    f"contiguous, got strides {kv_strides}"
                )
1582

1583
1584
                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
1585
1586
1587
1588
1589
1590
                    out = FP4Tensor(
                        data=output[:num_decode_tokens],
                        scale=output_block_scale,
                        scale_start_index=0,
                        original_shape=decode_query.shape,
                    )
1591
1592
1593
1594
                else:
                    assert self.o_sf_scale is None
                    out = output[:num_decode_tokens]

1595
1596
1597
1598
1599
                if num_decode_tokens % attn_metadata.num_decodes != 0:
                    # This gets triggered when the dummy_run forces
                    # attention to be initialized with q_len = 0
                    q_len_per_req = 1
                else:
1600
                    q_len_per_req = num_decode_tokens // attn_metadata.num_decodes
1601

1602
1603
1604
1605
1606
1607
                trtllm_batch_decode_with_kv_cache(
                    query=decode_query,
                    kv_cache=kv_cache_permute,
                    workspace_buffer=workspace_buffer,
                    block_tables=block_tables_decode,
                    seq_lens=seq_lens_decode,
1608
                    max_seq_len=attn_metadata.decode.max_seq_len,
1609
1610
1611
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    window_left=self.window_left,
1612
                    sinks=self.sinks,
1613
1614
                    o_sf_scale=self.o_sf_scale,
                    out=out,
1615
1616
                    q_len_per_req=q_len_per_req,
                )
1617
        return output_padded
1618

1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Store the given keys and values into ``kv_cache``.

        The write is skipped entirely when
        ``self.kv_sharing_target_layer_name`` is set, i.e. when this layer
        shares the KV cache of an earlier attention layer.

        Args:
            layer: Module providing the ``_k_scale`` / ``_v_scale`` values
                passed to the cache op.
            key: Keys to store.
            value: Values to store.
            kv_cache: Cache tensor; ``kv_cache[:, 0]`` and ``kv_cache[:, 1]``
                are handed to the cache op as the two destination buffers.
            slot_mapping: Destination slots for the stored tokens.
        """
        if self.kv_sharing_target_layer_name is not None:
            # KV cache is shared with an earlier attention layer: nothing
            # to write from this layer.
            return

        # NOTE(woosuk): key and value are padded while slot_mapping is not.
        # Slicing key[:num_actual_tokens] / value[:num_actual_tokens] is
        # unnecessary because reshape_and_cache_flash derives the number of
        # actual tokens from slot_mapping's shape.
        key_cache = kv_cache[:, 0]
        value_cache = kv_cache[:, 1]
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657

def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
1658
    logits_soft_cap: float | None = None,
1659
1660
    q_data_type: str | torch.dtype | None = "float16",
    kv_data_type: str | torch.dtype | None = None,
1661
    o_data_type: str | torch.dtype | None = None,
1662
    data_type: str | torch.dtype | None = None,
1663
1664
1665
    sm_scale: float | None = None,
    rope_scale: float | None = None,
    rope_theta: float | None = None,
1666
    non_blocking: bool = True,
1667
1668
    fixed_split_size: int = -1,
    disable_split_kv: bool = False,
1669
1670
) -> None:
    """
1671
1672
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
    cudagraph capture/replay, while the no cudagraph version turns back
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
    to the original plan.
    using original plan after passing host-side buffers:
    - only host-to-device copy of indptr and last_page_len buffers
    Modifications for cudagraph:
    - only host-to-device copy of indptr and last_page_len buffers.
    - avoid device-to-device copy of indices buffer.

    Part of the code get inspiration from the original plan from FlashInfer repo
    and the implementation of fast_decode_plan for FlashInfer in SGlang repo.
    """
    # Warm up with the original plan if it is first call, and always run the
    # original plan if we run for dynamic shape. For fixed shape (cudagraph),
    # this warm up is to generate the _cached_module for the decode wrapper.
1686
    if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
1687
        self.plan(
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
            indptr=indptr_cpu,
            indices=indices,
            last_page_len=last_page_len_cpu,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            page_size=page_size,
            pos_encoding_mode=pos_encoding_mode,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
            kv_data_type=kv_data_type,
            o_data_type=o_data_type,
            data_type=data_type,
            sm_scale=sm_scale,
            rope_scale=rope_scale,
            rope_theta=rope_theta,
            non_blocking=non_blocking,
            block_tables=None,
            seq_lens=None,
            fixed_split_size=fixed_split_size,
            disable_split_kv=disable_split_kv,
1710
1711
1712
1713
1714
1715
        )
        self.vllm_first_call = False
        return

    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
    fast_decode_plan(
        self,
        indptr=indptr_cpu,
        indices=indices,
        last_page_len=last_page_len_cpu,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=page_size,
        pos_encoding_mode=pos_encoding_mode,
        window_left=window_left,
        logits_soft_cap=logits_soft_cap,
        q_data_type=q_data_type,
        kv_data_type=kv_data_type,
        data_type=data_type,
        sm_scale=sm_scale,
        rope_scale=rope_scale,
        rope_theta=rope_theta,
        non_blocking=non_blocking,
        fixed_split_size=fixed_split_size,
        disable_split_kv=disable_split_kv,
1737
    )
1738

1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756

@triton.jit
def _copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    # Copy the leading entries of one block-table row into a flat
    # page_indices buffer.
    #
    # The row is selected by program_id(0). Consecutive entries of
    # cu_num_blocks delimit the destination range [start_idx, end_idx) in
    # page_indices, and therefore how many entries of the row are copied.
    req_idx = tl.program_id(0)
    row_ptr = block_table + req_idx * block_table_stride
    start_idx = tl.load(cu_num_blocks + req_idx)
    end_idx = tl.load(cu_num_blocks + req_idx + 1)
    num_blocks = end_idx - start_idx

    offset = tl.arange(0, BLOCK_SIZE)
    # Walk the row in BLOCK_SIZE-wide chunks; the mask keeps the final
    # partial chunk from reading or writing past num_blocks entries.
    for i in tl.range(0, num_blocks, BLOCK_SIZE):
        in_bounds = i + offset < num_blocks
        block_ids = tl.load(row_ptr + i + offset, mask=in_bounds)
        tl.store(
            page_indices + start_idx + i + offset,
            block_ids,
            mask=in_bounds,
        )