flashinfer.py 51.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer with FlashInfer."""
4

5
from dataclasses import dataclass
6
from typing import ClassVar
7

8
import numpy as np
9
import torch
10
11
12
13
14
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    MultiLevelCascadeAttentionWrapper,
)
15
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
16
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
17
from flashinfer.utils import FP4Tensor
18

19
20
21
22
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
23
    MultipleOf,
24
)
25
from vllm.config import CUDAGraphMode, VllmConfig
26
from vllm.logger import init_logger
27
from vllm.model_executor.layers.batch_invariant import (
28
    vllm_is_batch_invariant,
29
)
30
from vllm.model_executor.layers.quantization.utils.quant_utils import (
31
32
33
34
    QuantKey,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
35
from vllm.platforms import current_platform
36
from vllm.triton_utils import tl, triton
37
38
39
40
41
from vllm.utils.flashinfer import (
    can_use_trtllm_attention,
    flashinfer_disable_q_quantization,
    use_trtllm_attention,
)
42
from vllm.utils.math_utils import cdiv
43
from vllm.utils.platform_utils import is_pin_memory_available
44
45
46
47
48
49
50
51
52
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    get_kv_cache_layout,
    get_per_layer_parameters,
    infer_global_hyperparameters,
    split_decodes_and_prefills,
)
53
from vllm.v1.kv_cache_interface import AttentionSpec
54
55

# Size in bytes of the FlashInfer workspace buffer (256 MiB).
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 << 20
# Workspace size in bytes (2 GiB) used instead when batch-invariant mode is on.
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 << 20

# FP8 dtype of the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 data is carried in uint8 tensors.
FP4_DTYPE = torch.uint8

logger = init_logger(__name__)

63
64
65
66
67
68
69
# Process-wide workspace shared by TRT-LLM-gen kernels; allocated lazily.
trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
    """Return the shared TRT-LLM-gen workspace, allocating it on first use.

    The buffer is a zero-filled ``uint8`` tensor on the ``"cuda"`` device
    holding ``FLASHINFER_WORKSPACE_BUFFER_SIZE`` bytes. It is cached in the
    module-level ``trtllm_gen_workspace_buffer`` and reused by later calls.
    """
    global trtllm_gen_workspace_buffer
    if trtllm_gen_workspace_buffer is not None:
        return trtllm_gen_workspace_buffer
    trtllm_gen_workspace_buffer = torch.zeros(
        FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
    )
    return trtllm_gen_workspace_buffer

74

75
76
77
78
79
80
81
82
83
84
85
86
87
# Dequantize one FP8 KV page into the compact "mock" KV cache.
#
# Launched on a (batch_size, pages_per_seq) grid: program (b, j) reads page
# block_tables_prefill[b, j] of kv_cache, multiplies its K half by *k_scale and
# its V half by *v_scale, casts to the mock cache's element type and writes the
# result to mock page b * block_table_stride + j + 1 (mock page 0 is never
# written). Entries <= 0 in the block table are skipped.
#
# K_CACHE_STRIDE is the element count of one K (or V) half of a page and
# KV_CACHE_STRIDE that of a whole page. BLOCK_SIZE must be a power of two
# >= K_CACHE_STRIDE because tl.arange only accepts power-of-two extents; the
# excess lanes are masked off.
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
    kv_cache_ptr,
    block_tables_prefill_ptr,
    block_table_stride,
    mock_kv_cache_ptr,
    k_scale_ptr,
    v_scale_ptr,
    K_CACHE_STRIDE: tl.constexpr,
    KV_CACHE_STRIDE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # int64 so that page_num * KV_CACHE_STRIDE cannot overflow int32.
    batch_idx = tl.program_id(0).to(tl.int64)
    mock_block_table_idx = tl.program_id(1).to(tl.int64)
    orig_page_num = tl.load(
        block_tables_prefill_ptr + batch_idx * block_table_stride + mock_block_table_idx
    ).to(tl.int64)
    # Non-positive entries are skipped; their mock page stays uninitialized.
    if orig_page_num <= 0:
        return
    dequant_dtype = mock_kv_cache_ptr.dtype.element_ty

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < K_CACHE_STRIDE
    src_page = kv_cache_ptr + orig_page_num * KV_CACHE_STRIDE
    dst_page = (
        mock_kv_cache_ptr
        + (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE
    )

    # Dequantize K (first half of the page).
    k_scale_val = tl.load(k_scale_ptr)
    k_vals = tl.load(src_page + offs, mask=mask)
    k_vals = (k_vals.to(tl.float32) * k_scale_val).to(dequant_dtype)
    tl.store(dst_page + offs, k_vals, mask=mask)

    # Dequantize V (second half of the page).
    v_scale_val = tl.load(v_scale_ptr)
    v_vals = tl.load(src_page + K_CACHE_STRIDE + offs, mask=mask)
    v_vals = (v_vals.to(tl.float32) * v_scale_val).to(dequant_dtype)
    tl.store(dst_page + K_CACHE_STRIDE + offs, v_vals, mask=mask)


def trtllm_prefill_attn_kvfp8_dequant(
    kv_cache: torch.Tensor,
    block_tables_prefill: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    dequant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a dequantized copy of only the KV pages a prefill batch uses.

    Args:
        kv_cache: FP8 KV cache whose dim 1 has size 2 (K at index 0, V at
            index 1). The kernel addresses it with flat offsets derived from
            its shape, so it must be contiguous in that shape.
        block_tables_prefill: ``(batch_size, pages_per_seq)`` page ids into
            ``kv_cache``; entries <= 0 are skipped.
        k_scale: tensor whose first element is the K dequantization scale.
        v_scale: tensor whose first element is the V dequantization scale.
        dequant_dtype: output dtype, ``torch.bfloat16`` or ``torch.float16``.

    Returns:
        ``(mock_kv_cache, mock_block_table)``. ``mock_kv_cache`` has
        ``batch_size * pages_per_seq + 1`` pages in ``dequant_dtype``;
        ``mock_block_table`` maps each (request, page slot) to its mock page,
        numbered sequentially from 1 (mock page 0 is never written).
    """
    # The kernel uses the column count as the row stride, which only holds
    # for a contiguous table (no copy is made if it already is contiguous).
    block_tables_prefill = block_tables_prefill.contiguous()
    batch_size, num_of_page_per_token = block_tables_prefill.shape
    s = kv_cache.shape
    assert s[1] == 2
    assert dequant_dtype in (torch.bfloat16, torch.float16)
    k_cache_stride = s[2] * s[3] * s[4]
    kv_cache_stride = k_cache_stride * s[1]
    num_mock_pages = batch_size * num_of_page_per_token
    # mock kv cache contains just the pages needed by this prefill
    mock_kv_cache = torch.empty(
        (num_mock_pages + 1, s[1], s[2], s[3], s[4]),
        dtype=dequant_dtype,
        device=kv_cache.device,
    )
    # we simply sequentially index the pages needed by this prefill
    mock_block_table = torch.arange(
        start=1,
        end=num_mock_pages + 1,
        dtype=torch.int32,
        device=block_tables_prefill.device,
    ).reshape(batch_size, num_of_page_per_token)
    if num_mock_pages == 0:
        # Nothing to dequantize; avoid launching a kernel on an empty grid.
        return mock_kv_cache, mock_block_table
    grid = (batch_size, num_of_page_per_token)
    _trtllm_prefill_attn_kvfp8_dequant[grid](
        kv_cache,
        block_tables_prefill,
        num_of_page_per_token,
        mock_kv_cache,
        k_scale,
        v_scale,
        k_cache_stride,
        kv_cache_stride,
        # tl.arange needs a power-of-two extent; k_cache_stride is not one
        # when, e.g., the number of KV heads is not a power of two.
        BLOCK_SIZE=triton.next_power_of_2(k_cache_stride),
    )
    return mock_kv_cache, mock_block_table

158

159
class FlashInferBackend(AttentionBackend):
160
161
    accept_output_buffer: bool = True

162
163
164
165
    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

166
167
168
    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
169
170
        return [64, 128, 256]

171
    @staticmethod
172
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
173
174
175
176
177
        # Note: Not sure for all platforms,
        # but on Blackwell, only support a page size of
        # 16, 32, 64
        return [16, 32, 64]

178
179
180
181
182
183
184
185
186
    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
187
188
                "FlexAttention backend which supports all head sizes."
            )
189

190
191
    @staticmethod
    def get_name() -> str:
192
        return "FLASHINFER"
193
194

    @staticmethod
195
    def get_impl_cls() -> type["FlashInferImpl"]:
196
197
198
        return FlashInferImpl

    @staticmethod
199
    def get_metadata_cls() -> type["FlashInferMetadata"]:
200
201
202
        return FlashInferMetadata

    @staticmethod
203
    def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
204
205
206
207
208
209
210
211
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
212
        cache_dtype_str: str = "auto",
213
214
215
    ) -> tuple[int, ...]:
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

216
217
218
219
220
221
222
223
224
225
226
227
228
    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets us from
        # `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

229
230
231
232
233
234
235
236
237
    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

238
239
240
241
242
243
244
245
246
247

@dataclass
class FlashInferMetadata:
    """Per-batch attention metadata produced by ``FlashInferMetadataBuilder``.

    Fields are declared in positional-constructor order; keep that order
    stable when editing.
    """

    # Number of tokens in the batch, padding excluded.
    num_actual_tokens: int

    # dtype the query is expected to be in.
    q_data_type: torch.dtype

    slot_mapping: torch.Tensor

    # Inputs for the FlashInfer TRT-LLM batch kernels.
    max_q_len: int
    max_q_len_prefill: int
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table_tensor: torch.Tensor
    prefill_use_trtllm: bool
    decode_use_trtllm: bool

    # Sizes of the decode / prefill split of the batch.
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # Whether cascade attention is used (its planning inputs live on CPU).
    use_cascade: bool

    # FlashInfer wrappers planned for this batch; left as None when unused.
    prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None
    decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
    cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None

    # Device copies of the prefill index pointers (TRT-LLM prefill path only).
    qo_indptr_gpu: torch.Tensor | None = None
    paged_kv_indptr_gpu: torch.Tensor | None = None
272

273

274
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Builds ``FlashInferMetadata`` per batch and plans the FlashInfer wrappers.

    Owns the FlashInfer workspace buffer and the prefill / decode / cascade
    wrappers (all created lazily), plus persistent device-side and host-side
    buffers for the paged-KV indptr, indices and last-page-length arrays that
    are reused by every ``build`` call.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )

    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Read model/cache configuration and allocate persistent buffers.

        Raises:
            NotImplementedError: the layers use attention sinks but TRT-LLM
                attention cannot be used for this head configuration.
            ValueError: the head size is not supported by FlashInferBackend.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.cache_config = vllm_config.cache_config
        self.model_config = vllm_config.model_config
        # Workspace and wrappers are created on first use (see _get_* below).
        self._workspace_buffer = None
        self._prefill_wrapper = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)

        # In batch-invariant mode, plan() is given fixed split sizes and
        # split-KV is disabled; otherwise -1 / False are passed through
        # (presumably "let FlashInfer decide" -- confirm against FlashInfer).
        if vllm_is_batch_invariant():
            self.decode_fixed_split_size = 2048
            self.prefill_fixed_split_size = 4096
            self.disable_split_kv = True
        else:
            self.decode_fixed_split_size = -1
            self.prefill_fixed_split_size = -1
            self.disable_split_kv = False

        self.compilation_config = vllm_config.compilation_config
        max_num_pages_per_req = cdiv(
            self.model_config.max_model_len, self.kv_cache_spec.block_size
        )
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        # Upper bound on pages referenced by one batch; sizes paged_kv_indices.
        max_num_pages = max_num_reqs * max_num_pages_per_req
        speculative_config = vllm_config.speculative_config
        num_spec_tokens = (
            speculative_config.num_speculative_tokens
            if speculative_config is not None
            else 0
        )
        self.enable_cuda_graph = (
            self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
        )
        if self.enable_cuda_graph:
            # For full cudagraph capture, one `decode_wrapper` for each batch
            # size is needed for FlashInfer.
            self._decode_wrappers_cudagraph: dict[
                int, BatchDecodeWithPagedKVCacheWrapper
            ] = {}
            # Largest decode token count that may be served by a captured
            # graph: each request can carry 1 + num_spec_tokens tokens.
            self._decode_cudagraph_max_bs = min(
                (1 + num_spec_tokens) * max_num_reqs,
                self.compilation_config.max_cudagraph_capture_size,
            )

        self.num_qo_heads = self.model_config.get_num_attention_heads(
            self.vllm_config.parallel_config
        )
        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
        self.head_dim = self.kv_cache_spec.head_size
        FlashInferBackend.validate_head_size(self.head_dim)
        self.page_size = self.kv_cache_spec.block_size

        self.cache_dtype = self.cache_config.cache_dtype
        if self.cache_dtype.startswith("fp8"):
            self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.cache_dtype
            )
        else:
            assert self.kv_cache_spec.dtype == self.model_config.dtype
            self.kv_cache_dtype = self.kv_cache_spec.dtype

        # Use model dtype as q dtype when TRTLLM attn is not supported, or
        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
        # use fp8 q if kv cache is fp8, and will fall back to model dtype
        # if TRTLLM attention kernel is not used when building attn metadata
        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
        if can_use_trtllm and not flashinfer_disable_q_quantization():
            self.q_data_type = self.kv_cache_dtype
        else:
            self.q_data_type = self.model_config.dtype

        # Speculative tokens may be treated as decodes only when TRT-LLM
        # attention is usable.
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)

        self._cascade_wrapper = None  # Wrapper for cascade attention

        # Global hyperparameters shared by all attention layers
        # TODO: discard this for trtllm-gen backend
        self.global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)
        )
        self.sm_scale = self.global_hyperparameters.sm_scale
        self.window_left = self.global_hyperparameters.window_left
        self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
        self.has_sinks = self.global_hyperparameters.has_sinks
        if self.has_sinks and not can_use_trtllm:
            raise NotImplementedError(
                "FlashInfer backend currently does not support attention "
                "sinks, please use trtllm on blackwell or flash attention on "
                "earlier GPUs."
            )
        # Preparing persistent buffers (device-side)
        self.paged_kv_indptr = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device=self.device
        )
        self.paged_kv_indices = torch.zeros(
            max_num_pages,  # max num pages possible
            dtype=torch.int32,
            device=self.device,
        )
        self.paged_kv_last_page_len = torch.zeros(
            max_num_reqs, dtype=torch.int32, device=self.device
        )
        # host-side buffer
        pin_memory = is_pin_memory_available()
        self.paged_kv_indptr_cpu = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        # NumPy view sharing memory with paged_kv_indptr_cpu; build() writes
        # through it and reads the tensor.
        self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
        # Staging copy used as the source of the async host-to-device copy,
        # so later in-place edits of paged_kv_indptr_cpu cannot race with it.
        self.paged_kv_indptr_buffer = torch.zeros_like(
            self.paged_kv_indptr_cpu, pin_memory=pin_memory
        )
        self.paged_kv_indices_cpu = torch.zeros(
            max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.paged_kv_last_page_len_cpu = torch.zeros(
            max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        # NumPy view sharing memory with paged_kv_last_page_len_cpu.
        self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()

        if self.head_dim == 256 and current_platform.is_device_capability(100):
            # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
            # head size 256 and block size 16 is not supported on blackwell.
            assert kv_cache_spec.block_size != 16, (
                "There is a bug in FlashInfer "
                "block_size 16 head size 256 support. Please avoid this combination by "
                "passing --block-size 32 or --block-size 64."
            )

    def _get_workspace_buffer(self):
        """Return the shared workspace, allocating it on ``self.device`` once.

        The larger batch-invariant size is used when batch-invariant mode is on.
        """
        if self._workspace_buffer is None:
            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
            if vllm_is_batch_invariant():
                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
            self._workspace_buffer = torch.zeros(
                buffer_size, dtype=torch.uint8, device=self.device
            )
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        """Return the prefill wrapper, creating it on first use."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), get_kv_cache_layout()
            )
        return self._prefill_wrapper

    def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
        """Return a decode wrapper, creating and caching it on first use.

        With ``use_cudagraph`` one wrapper is kept per ``batch_size`` and it is
        bound to slices of the persistent device buffers; otherwise a single
        general-shape wrapper without fixed buffers is shared. The cudagraph
        path requires ``self.enable_cuda_graph`` (it uses
        ``_decode_wrappers_cudagraph``, which only exists in that case).
        """
        if use_cudagraph:
            decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None)
        else:
            decode_wrapper = self._decode_wrapper

        if decode_wrapper is None:
            if use_cudagraph:
                paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1]
                paged_kv_indices = self.paged_kv_indices
                paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size]
            else:
                paged_kv_indptr = None
                paged_kv_indices = None
                paged_kv_last_page_len = None
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                get_kv_cache_layout(),
                use_cuda_graph=use_cudagraph,
                paged_kv_indptr_buffer=paged_kv_indptr,
                paged_kv_indices_buffer=paged_kv_indices,
                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                # Tensor cores are enabled by default because the perf would be
                # at least as good as cuda cores for all attention ops in latest
                # gpus.
                use_tensor_cores=True,
            )

            # save the decode wrapper
            if use_cudagraph:
                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
            else:
                self._decode_wrapper = decode_wrapper

        return decode_wrapper

    def _get_cascade_wrapper(self):
        """Return the cascade wrapper (created with 2 levels) on first use."""
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
                2, self._get_workspace_buffer(), get_kv_cache_layout()
            )
        return self._cascade_wrapper

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashInferMetadata:
        """Build the metadata for one batch and plan the wrappers it needs.

        Args:
            common_prefix_len: length of the prefix shared by all requests;
                > 0 selects the cascade path and must be a multiple of the
                page size.
            common_attn_metadata: batch-level inputs (sequence lengths,
                block table, query start offsets, ...).
            fast_build: not read in this method.

        Returns:
            A ``FlashInferMetadata`` with the planned wrappers attached.

        Raises:
            NotImplementedError: attention sinks are present but a
                non-TRT-LLM kernel would be used.
            ValueError: layers disagree on ``window_left`` on the
                non-TRT-LLM path.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True,
            )
        )

        page_size = self.page_size
        max_q_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        seq_lens_np = seq_lens_cpu.numpy()
        block_table_tensor = common_attn_metadata.block_table_tensor

        # Pages per request = ceil(seq_len / page_size).
        num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # Grab the blocks of the shared prefix from the first request.
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size

            # Create CPU versions directly for cascade (no GPU versions needed)
            shared_qo_indptr_cpu = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device="cpu"
            )
            shared_kv_page_indptr_cpu = torch.tensor(
                [0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
            )
            # NOTE(review): despite the _cpu suffix, this is a slice of
            # block_table_tensor and lives on whatever device that tensor is on.
            shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
            shared_kv_last_page_len_cpu = torch.tensor(
                [page_size], dtype=torch.int32, device="cpu"
            )

            # Remove the blocks of the shared prefix from all requests.
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
            num_blocks_np -= num_common_kv_blocks
        else:
            shared_qo_indptr_cpu = None
            shared_kv_page_indptr_cpu = None
            shared_kv_page_indices_cpu = None
            shared_kv_last_page_len_cpu = None

        # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
        np.cumsum(
            num_blocks_np,
            dtype=np.int32,
            out=self.paged_kv_indptr_np[1 : num_reqs + 1],
        )
        # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
        # after this line (e.g., for cuda graphs), we need to copy the data to
        # self.paged_kv_indptr_buffer to avoid race condition.
        self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
            : num_reqs + 1
        ]
        paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
        paged_kv_indptr.copy_(
            self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
        )

        # write self.paged_kv_indices inplace
        num_actual_pages = self.paged_kv_indptr_np[num_reqs]
        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
        # One program per request gathers that request's pages from the
        # (possibly prefix-trimmed) block table into paged_kv_indices.
        _copy_page_indices_kernel[(num_reqs,)](
            paged_kv_indices,
            block_table_tensor,
            block_table_tensor.stride(0),
            paged_kv_indptr,
            BLOCK_SIZE=1024,
        )

        # write self.paged_kv_last_page_len_cpu inplace
        # A remainder of 0 means the last page is full, so store page_size.
        paged_kv_last_page_len_np = seq_lens_np % page_size
        self.paged_kv_last_page_len_np[:num_reqs] = np.where(
            paged_kv_last_page_len_np == 0,
            page_size,
            paged_kv_last_page_len_np,
        )

        uses_spec_reorder = self.reorder_batch_threshold > 1
        prefill_use_trtllm = use_trtllm_attention(
            self.num_qo_heads,
            self.num_kv_heads,
            num_prefill_tokens,
            max_seq_len,
            self.cache_dtype,
            self.q_data_type,
            is_prefill=True,
            has_sinks=self.has_sinks,
            has_spec=uses_spec_reorder,
        )
        decode_use_trtllm = use_trtllm_attention(
            self.num_qo_heads,
            self.num_kv_heads,
            num_decode_tokens,
            max_seq_len,
            self.cache_dtype,
            self.q_data_type,
            is_prefill=False,
            has_sinks=self.has_sinks,
            has_spec=uses_spec_reorder,
        )

        if not (prefill_use_trtllm and decode_use_trtllm):
            if self.has_sinks:
                raise NotImplementedError(
                    "FlashInfer backend currently does not support attention "
                    "sinks, please use trtllm on blackwell or flash attention "
                    "on earlier GPUs."
                )

            if not self.global_hyperparameters.has_same_window_lefts:
                raise ValueError(
                    "Window left is not the same for all layers. "
                    "One potential fix is to set disable_sliding_window=True"
                )

            assert self.global_hyperparameters.has_same_all_params, (
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values for the following "
                "hyperparameters: `window_left`, `logits_soft_cap`, "
                "`sm_scale`."
            )

            # The q quantization is not supported for non-trtllm attention,
            # fall back to model dtype.
            # NOTE(review): this overwrites the builder attribute, so every
            # later build() also starts from the model dtype -- confirm the
            # stickiness is intended.
            self.q_data_type = self.model_config.dtype

        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            q_data_type=self.q_data_type,
            slot_mapping=common_attn_metadata.slot_mapping,
            max_q_len=max_q_len,
            max_q_len_prefill=max_q_len,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table_tensor=block_table_tensor,
            prefill_use_trtllm=prefill_use_trtllm,
            decode_use_trtllm=decode_use_trtllm,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            use_cascade=use_cascade,
        )

        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
        paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
        paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]

        if attn_metadata.use_cascade:
            # Level 0 is the shared prefix, level 1 the per-request remainder.
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [shared_qo_indptr_cpu, qo_indptr_cpu],
                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                [shared_kv_page_indices_cpu, paged_kv_indices],
                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                causal=True,
                sm_scale=self.sm_scale,
                window_left=self.window_left,
                logits_soft_cap=self.logits_soft_cap,
                q_data_type=self.q_data_type,
                kv_data_type=self.kv_cache_dtype,
            )
        else:
            # Regular attention (common case).
            # Decodes are at the front and prefills are at the back.
            num_prefills = attn_metadata.num_prefills
            num_decodes = attn_metadata.num_decodes
            if num_prefills > 0:
                # Decodes are first so prefills start after the last decode
                prefill_start = num_decodes
                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
                assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
                assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
                assert (
                    paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills
                )
                # Since prefill_wrapper.run() will be called with
                # query[num_decode_tokens:] we need to adjust the qo_indptr
                # to be relative to the start of the prefill queries.
                qo_indptr_cpu = (
                    qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
                )
                # Note: the KV indptr is not rebased; it still indexes the
                # full paged_kv_indices array.
                paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]

                # Recompute max_q_len for the slice of requests we are using
                # for prefills. This can be different from max_q_len when
                # we have a non-uniform batch with some short decodes offloaded
                # to the prefill pathway
                query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
                attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())

                if not attn_metadata.prefill_use_trtllm:
                    attn_metadata.prefill_wrapper.plan(
                        qo_indptr_cpu,
                        paged_kv_indptr_cpu,
                        paged_kv_indices,
                        paged_kv_last_page_len_cpu[prefill_start:],
                        self.num_qo_heads,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        causal=True,
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                        fixed_split_size=self.prefill_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
                    )
                else:
                    # TRT-LLM prefill reads the index pointers from device
                    # memory instead of a planned wrapper.
                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
                        self.device, non_blocking=True
                    )
                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
                        self.device, non_blocking=True
                    )

            if num_decodes > 0:
                pure_decode = num_prefills == 0
                # possible required padding for cudagraph replay
                use_cudagraph = (
                    self.enable_cuda_graph
                    and pure_decode
                    and num_decode_tokens <= self._decode_cudagraph_max_bs
                )
                if use_cudagraph:
                    num_input_tokens = self.vllm_config.pad_for_cudagraph(
                        num_decode_tokens
                    )
                    # Carefully fulfill the padding region with reasonable value
                    # on cpu.
                    # Make sure paged_kv_indptr_cpu is not decreasing
                    self.paged_kv_indptr_cpu[
                        1 + num_decodes : 1 + num_input_tokens
                    ].fill_(paged_kv_indptr_cpu[-1])
                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
                    # This is because flashinfer treats 0 as a full page
                    # instead of empty.
                    self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
                        1
                    )

                else:
                    num_input_tokens = num_decode_tokens

                # The wrapper is keyed by the (padded) decode token count,
                # not by num_decodes.
                attn_metadata.decode_wrapper = self._get_decode_wrapper(
                    num_input_tokens, use_cudagraph
                )
                if not attn_metadata.decode_use_trtllm:
                    # Use the persistent buffer with padding length,
                    # instead of the same address but chunked version
                    # in atten_metadata when using cudagraph.
                    fast_plan_decode(
                        attn_metadata.decode_wrapper,
                        self.paged_kv_indptr_cpu[: num_input_tokens + 1],
                        paged_kv_indices,
                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
                        seq_lens_cpu[:num_input_tokens],
                        self.num_qo_heads,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        # Disable flashinfer's pos encoding and use vllm's rope.
                        pos_encoding_mode="NONE",
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                        fixed_split_size=self.decode_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
                    )
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Report whether cascade attention should be used; currently always False."""
        if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        # TODO: Cascade attention doesn't work, disable it for now
        # return use_cascade_attention(*args, **kwargs)
        return False
773
774
775
776
777
778
779
780
781


class FlashInferImpl(AttentionImpl):
    """FlashInfer-based attention implementation (decoder self-attention only).

    In ``forward`` the decode tokens (front of the batch) and prefill tokens
    (back of the batch) are handled separately. Each phase runs either through
    the FlashInfer wrapper planned by the metadata builder, or through the
    TRT-LLM-gen kernels when the metadata selects them
    (``prefill_use_trtllm`` / ``decode_use_trtllm``).
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (left, right) window; (-1, -1) when no sliding window is configured.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        # self.sliding_window is always a tuple at this point, so the left
        # bound can be read directly (-1 without a sliding window).
        self.window_left = self.sliding_window[0]
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferImpl"
            )

        self.sinks: torch.Tensor | None = None
        if sinks is not None:
            if sinks.shape[0] != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}."
                )
            self.sinks = sinks

        self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
        # Host-side scalars for the TRT-LLM kernels. They are filled lazily on
        # the first forward() call from the layer's quantization scales.
        self.bmm1_scale: float | None = None
        self.bmm2_scale: float | None = None
        self.o_sf_scale: float | None = None

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """Return whether attention can be fused with output quantization.

        Requires TRT-LLM attention support, an fp8 KV cache, and a quant key of
        either ``kFp8StaticTensorSym`` or ``kNvfp4Quant``.
        """
        return (
            self.support_trtllm_attn
            and self.kv_cache_dtype.startswith("fp8")
            and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
        )

    def supports_quant_query_input(self) -> bool:
        """Return whether the query may be passed in already quantized.

        Always False when ``flashinfer_disable_q_quantization()`` is set;
        otherwise this follows TRT-LLM attention support.
        """
        if flashinfer_disable_q_quantization():
            return False

        return self.support_trtllm_attn

    # FlashInfer requires attention sinks to be float32
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Cast attention sinks (if any) to float32 after weight loading."""
        if self.sinks is not None and self.sinks.dtype != torch.float32:
            self.sinks = self.sinks.to(torch.float32)

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: KV cache tensor with different possible shapes:
                - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
                - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention (None during profiling).
            output: Pre-allocated output buffer (required). Results are
                written into it in place.
            output_scale: When given, attention is fused with output
                quantization (fp8 or nvfp4 output). This requires an fp8 query
                and TRT-LLM kernels for both prefill and decode.
            output_block_scale: Block-scale tensor. It is required for nvfp4
                output and must be None otherwise.
        Returns:
            shape = [num_tokens, num_heads * head_size]
            (the full, possibly padded, ``output`` buffer on the regular path)
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # Ensure query dtype matches the expected dtype from attention metadata
        assert attn_metadata.q_data_type == query.dtype, (
            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
            f"got {query.dtype}"
        )

        # TRT-LLM kernel scales, computed once from the layer's quant scales.
        if self.bmm1_scale is None:
            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        # The attn+quant fusion happens when output_scale is provided.
        if output_scale is None:
            assert output_block_scale is None, (
                "output_block_scale is not supported when fusion has not happened"
            )
        else:
            assert attn_metadata.q_data_type == FP8_DTYPE, (
                "Query must be FP8 when attn+quant fusion happened."
            )
            assert (
                attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm
            ), "Must use TRT-LLM attn"

            if output.dtype == FP8_DTYPE:
                assert output_block_scale is None, (
                    "output_block_scale should not be provided for fp8 output"
                )
            elif output.dtype == FP4_DTYPE:
                assert output_block_scale is not None, (
                    "output_block_scale is required for nvfp4 output"
                )
            else:
                raise ValueError(f"Unsupported output dtype: {output.dtype}")

            # The TRT-LLM attention kernel needs the output scale as a host
            # scalar, so read it back once in the warmup run (when cuda graph
            # is not enabled) and cache it on the layer.
            if layer._o_scale_float is None:
                layer._o_scale_float = output_scale.cpu().item()
                if output.dtype == FP8_DTYPE:
                    self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
                elif output.dtype == FP4_DTYPE:
                    self.o_sf_scale = layer._o_scale_float

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8. Layers that
        # share an earlier layer's KV cache skip the write above but still
        # read the cache below, so the view must be applied for them too.
        if self.kv_cache_dtype.startswith("fp8"):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype
            )
            kv_cache = kv_cache.view(torch_dtype)

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
        output_padded = output
        output = output[:num_actual_tokens]

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            return output

        # When using spec decoding, num_decodes can be < num_decode_tokens
        # because some decode requests may have more than one query token.
        num_decodes = attn_metadata.num_decodes
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        kv_cache_permute = kv_cache.permute(*stride_order)
        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back.
        if num_prefill_tokens > 0:
            prefill_wrapper = attn_metadata.prefill_wrapper
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
            assert prefill_wrapper is not None

            if not attn_metadata.prefill_use_trtllm:
                # The wrapper was planned by the metadata builder; make sure
                # its parameters agree with this layer.
                assert prefill_wrapper._causal
                assert prefill_wrapper._window_left == self.window_left
                assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
                assert prefill_wrapper._sm_scale == self.scale
                prefill_wrapper.run(
                    prefill_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[num_decode_tokens:],
                )
            else:
                # prefill_query may be non-contiguous
                prefill_query = prefill_query.contiguous()
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:]
                seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert prefill_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_prefill.is_contiguous()
                assert seq_lens_prefill.is_contiguous()

                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
                    out = FP4Tensor(
                        data=output[num_decode_tokens:],
                        scale=output_block_scale,
                        scale_start_index=num_decode_tokens,
                        original_shape=prefill_query.shape,
                    )
                else:
                    assert self.o_sf_scale is None
                    out = output[num_decode_tokens:]

                if (
                    attn_metadata.q_data_type != FP8_DTYPE
                    and self.kv_cache_dtype.startswith("fp8")
                ):
                    # TRTLLM prefill attention does not support BF16 Q
                    # and fp8 kv cache. So to enable prefill attention
                    # with fp8 kv cache, we can construct a mock block
                    # and mock kv cache with BF16 KV involved in the prefill
                    mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
                        kv_cache_permute,
                        block_tables_prefill,
                        layer._k_scale,
                        layer._v_scale,
                        attn_metadata.q_data_type,
                    )
                else:
                    mock_kv_cache = kv_cache_permute
                    mock_block_table = block_tables_prefill

                trtllm_batch_context_with_kv_cache(
                    query=prefill_query,
                    kv_cache=mock_kv_cache,
                    workspace_buffer=workspace_buffer,
                    block_tables=mock_block_table,
                    seq_lens=seq_lens_prefill,
                    max_q_len=attn_metadata.max_q_len_prefill,
                    max_kv_len=attn_metadata.max_seq_len,
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    batch_size=attn_metadata.num_prefills,
                    cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                    cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
                    window_left=self.window_left,
                    sinks=self.sinks,
                    o_sf_scale=self.o_sf_scale,
                    out=out,
                )

        if num_decode_tokens > 0:
            decode_wrapper = attn_metadata.decode_wrapper
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            assert decode_wrapper is not None

            if not attn_metadata.decode_use_trtllm:
                # The wrapper was planned by the metadata builder; make sure
                # its parameters agree with this layer.
                assert decode_wrapper._window_left == self.window_left
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
                assert decode_wrapper._sm_scale == self.scale
                decode_wrapper.run(
                    decode_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[:num_decode_tokens],
                )
            else:
                # decode_query may be non-contiguous
                decode_query = decode_query.contiguous()
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                # NOTE(review): the prefill path slices per-request tensors by
                # num_decodes, while this slices by num_decode_tokens (which
                # also covers the q_len_per_req = 1 fallback below) — confirm
                # this is intended for spec-decode batches.
                block_tables_decode = attn_metadata.block_table_tensor[
                    :num_decode_tokens
                ]
                seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert decode_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_decode.is_contiguous()
                assert seq_lens_decode.is_contiguous()

                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
                    out = FP4Tensor(
                        data=output[:num_decode_tokens],
                        scale=output_block_scale,
                        scale_start_index=0,
                        original_shape=decode_query.shape,
                    )
                else:
                    assert self.o_sf_scale is None
                    out = output[:num_decode_tokens]

                if num_decode_tokens % num_decodes != 0:
                    # This gets triggered when the dummy_run forces
                    # attention to be initialized with q_len = 0
                    q_len_per_req = 1
                else:
                    q_len_per_req = num_decode_tokens // num_decodes

                trtllm_batch_decode_with_kv_cache(
                    query=decode_query,
                    kv_cache=kv_cache_permute,
                    workspace_buffer=workspace_buffer,
                    block_tables=block_tables_decode,
                    seq_lens=seq_lens_decode,
                    max_seq_len=attn_metadata.max_seq_len,
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    window_left=self.window_left,
                    sinks=self.sinks,
                    o_sf_scale=self.o_sf_scale,
                    out=out,
                    q_len_per_req=q_len_per_req,
                )
        return output_padded


def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
1146
    seq_lens_cpu: torch.Tensor,
1147
1148
1149
1150
1151
1152
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
1153
    logits_soft_cap: float | None = None,
1154
1155
1156
    q_data_type: str | torch.dtype | None = "float16",
    kv_data_type: str | torch.dtype | None = None,
    data_type: str | torch.dtype | None = None,
1157
1158
1159
    sm_scale: float | None = None,
    rope_scale: float | None = None,
    rope_theta: float | None = None,
1160
    non_blocking: bool = True,
1161
1162
    fixed_split_size: int = -1,
    disable_split_kv: bool = False,
1163
1164
) -> None:
    """
1165
1166
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
    cudagraph capture/replay, while the no cudagraph version turns back
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
    to the original plan.
    using original plan after passing host-side buffers:
    - only host-to-device copy of indptr and last_page_len buffers
    Modifications for cudagraph:
    - only host-to-device copy of indptr and last_page_len buffers.
    - avoid device-to-device copy of indices buffer.

    Part of the code get inspiration from the original plan from FlashInfer repo
    and the implementation of fast_decode_plan for FlashInfer in SGlang repo.
    """
    # Warm up with the original plan if it is first call, and always run the
    # original plan if we run for dynamic shape. For fixed shape (cudagraph),
    # this warm up is to generate the _cached_module for the decode wrapper.
1180
    if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
        self.plan(
            indptr_cpu,
            indices,
            last_page_len_cpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            pos_encoding_mode,
            window_left,
            logits_soft_cap,
            q_data_type,
            kv_data_type,
            data_type,
            sm_scale,
            rope_scale,
            rope_theta,
            non_blocking,
1199
1200
1201
1202
            None,  # block_tables
            None,  # seq_lens
            fixed_split_size,
            disable_split_kv,
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
        )
        self.vllm_first_call = False
        return

    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

    batch_size = len(last_page_len_cpu)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type
1224
1225
1226
1227
1228
1229
    q_data_type = (
        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
    )
    kv_data_type = (
        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
    )
1230
1231
1232
1233
1234

    if batch_size != self._fixed_batch_size:
        raise ValueError(
            "The batch size should be fixed in cudagraph mode, the runtime "
            "batch size {} mismatches the batch size set during "
1235
1236
            "initialization {}".format(batch_size, self._fixed_batch_size)
        )
1237
1238
    if len(indices) > len(self._paged_kv_indices_buf):
        raise ValueError(
1239
1240
            "The size of indices should be less than or equal to the allocated buffer"
        )
1241
1242
1243
1244

    # host-to-device copy for the indptr buffer
    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
    # host-to-device copy for the last_page_len buffer
1245
    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True)
1246

1247
1248
1249
    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

    try:
1250
        # Make sure we pass exactly 18 arguments for tensor core version
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
        self._plan_info = self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_host,
            indptr_cpu,
            seq_lens_cpu,
            batch_size,  # total_num_rows
            batch_size,
            num_qo_heads,
            num_kv_heads,
            page_size,
            self.is_cuda_graph_enabled,
            head_dim,
            head_dim,
            False,  # causal
1267
            window_left,
1268
1269
            fixed_split_size,
            disable_split_kv,
1270
1271
1272
        )
    except Exception as e:
        raise RuntimeError(f"Error in tensor core plan: {e}") from e
1273
1274
1275
1276
1277
1278
1279

    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298


@triton.jit
def _copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    # Gather each request's block ids into one flat page-indices array.
    # One program per request (program_id(0)): request r copies
    # cu_num_blocks[r + 1] - cu_num_blocks[r] entries from the start of its
    # block_table row into page_indices starting at cu_num_blocks[r],
    # BLOCK_SIZE entries per loop iteration.
    req_idx = tl.program_id(0)
    row_ptr = block_table + req_idx * block_table_stride
    start_idx = tl.load(cu_num_blocks + req_idx)
    end_idx = tl.load(cu_num_blocks + req_idx + 1)
    num_blocks = end_idx - start_idx

    offset = tl.arange(0, BLOCK_SIZE)
    for i in tl.range(0, num_blocks, BLOCK_SIZE):
        # Mask off lanes past this request's block count in the last chunk.
        mask = i + offset < num_blocks
        block_ids = tl.load(row_ptr + i + offset, mask=mask)
        tl.store(
            page_indices + start_idx + i + offset,
            block_ids,
            mask=mask,
        )