flashinfer.py 61.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer with FlashInfer."""
4

5
from dataclasses import dataclass
6
from typing import ClassVar
7

8
import numpy as np
9
import torch
10
11
12
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
13
    BatchPrefillWithRaggedKVCacheWrapper,
14
15
    MultiLevelCascadeAttentionWrapper,
)
16
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
17
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
18
from flashinfer.utils import FP4Tensor
19

20
from vllm import envs
21
22
23
24
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
25
    MultipleOf,
26
)
27
28
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
29
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
30
from vllm.config.cache import CacheDType
31
from vllm.distributed.parallel_state import get_dcp_group
32
from vllm.logger import init_logger
33
from vllm.model_executor.layers.batch_invariant import (
34
    vllm_is_batch_invariant,
35
)
36
from vllm.model_executor.layers.quantization.utils.quant_utils import (
37
38
39
40
    QuantKey,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
41
from vllm.platforms import current_platform
42
from vllm.platforms.interface import DeviceCapability
43
from vllm.triton_utils import tl, triton
44
45
46
47
from vllm.utils.flashinfer import (
    can_use_trtllm_attention,
    use_trtllm_attention,
)
48
from vllm.utils.math_utils import cdiv
49
from vllm.utils.platform_utils import is_pin_memory_available
50
51
52
53
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
54
    KVCacheLayoutType,
55
    get_dcp_local_seq_lens,
56
57
58
59
60
    get_kv_cache_layout,
    get_per_layer_parameters,
    infer_global_hyperparameters,
    split_decodes_and_prefills,
)
61
from vllm.v1.kv_cache_interface import AttentionSpec
62

63
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
64

65
FP8_DTYPE = current_platform.fp8_dtype()
66
FP4_DTYPE = torch.uint8
67

68
69
logger = init_logger(__name__)

70
71
72
73
74
75
76
trtllm_gen_workspace_buffer = None


def _get_trtllm_gen_workspace_buffer():
    """Return the module-wide TRTLLM-gen workspace buffer.

    The buffer is allocated lazily on the first call as a zero-filled
    ``torch.uint8`` tensor of ``envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE``
    elements on the ``"cuda"`` device, cached in the module-level
    ``trtllm_gen_workspace_buffer`` global, and returned as-is on every
    subsequent call.
    """
    global trtllm_gen_workspace_buffer
    if trtllm_gen_workspace_buffer is None:
        trtllm_gen_workspace_buffer = torch.zeros(
            envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
        )
    return trtllm_gen_workspace_buffer

81

82
83
84
85
86
87
88
89
90
91
92
93
94
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
    kv_cache_ptr,
    block_tables_prefill_ptr,
    block_table_stride,
    mock_kv_cache_ptr,
    k_scale_ptr,
    v_scale_ptr,
    K_CACHE_STRIDE: tl.constexpr,
    KV_CACHE_STRIDE: tl.constexpr,
):
    # Dequantize one KV-cache page per program instance.
    #
    # Program axis 0 selects the block-table row, axis 1 the column within
    # that row. The page number found at that entry is read from
    # `kv_cache_ptr`, its K half and V half are each multiplied by a single
    # scale value loaded from `k_scale_ptr` / `v_scale_ptr`, and the result
    # is stored into `mock_kv_cache_ptr` at page
    # `row * block_table_stride + col + 1`.
    #
    # K_CACHE_STRIDE is the number of elements in the K (or V) half of a
    # page; KV_CACHE_STRIDE is the number of elements in a whole page.
    batch_idx = tl.program_id(0).to(tl.int64)
    mock_block_table_idx = tl.program_id(1).to(tl.int64)
    orig_page_num = tl.load(
        block_tables_prefill_ptr + batch_idx * block_table_stride + mock_block_table_idx
    ).to(tl.int64)
    # Entries <= 0 are skipped: nothing is written for them, so the
    # corresponding destination page keeps whatever it already contained.
    if orig_page_num <= 0:
        return
    dequant_dtype = mock_kv_cache_ptr.dtype.element_ty

    # Dequantize K
    k_scale_val = tl.load(k_scale_ptr)
    offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
    fp8_vals = tl.load(kv_cache_ptr + offset)
    dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
    # Destination page index is shifted by one (page 0 is never written).
    mock_cache_offset = (
        batch_idx * block_table_stride + mock_block_table_idx + 1
    ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
    dequantized_vals = dequantized_vals.to(dequant_dtype)
    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)

    # Dequantize V
    v_scale_val = tl.load(v_scale_ptr)
    # The V half starts K_CACHE_STRIDE elements after the K half of the page.
    offset = (
        orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
    )
    fp8_vals = tl.load(kv_cache_ptr + offset)
    dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
    mock_cache_offset = (
        (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE
        + K_CACHE_STRIDE
        + tl.arange(0, K_CACHE_STRIDE)
    )
    dequantized_vals = dequantized_vals.to(dequant_dtype)
    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)


def trtllm_prefill_attn_kvfp8_dequant(
    kv_cache: torch.Tensor,
    block_tables_prefill: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    dequant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dequantize the KV-cache pages referenced by a prefill block table.

    Builds a compact "mock" KV cache holding one page per block-table entry,
    filled by the ``_trtllm_prefill_attn_kvfp8_dequant`` Triton kernel, plus a
    matching block table that indexes those pages sequentially.

    Args:
        kv_cache: 5-D cache tensor whose dim 1 must have size 2 (K and V).
            The kernel addresses it with flat per-page offsets, so it is
            presumably expected to be contiguous - TODO confirm at call sites.
        block_tables_prefill: 2-D tensor of page numbers, one row per request.
        k_scale: Tensor whose first element is the K dequantization scale.
        v_scale: Tensor whose first element is the V dequantization scale.
        dequant_dtype: Output dtype; must be ``torch.bfloat16`` or
            ``torch.float16``.

    Returns:
        ``(mock_kv_cache, mock_block_table)``. ``mock_kv_cache`` has
        ``rows * cols + 1`` pages (page 0 is never written) and is allocated
        with ``torch.empty``, so pages whose source entry is ``<= 0`` are left
        uninitialized. ``mock_block_table`` is an ``int32`` tensor with the
        same shape as ``block_tables_prefill`` containing ``1..rows*cols``.
    """
    # The kernel receives the number of columns as the row stride, so the
    # table must be laid out densely. `.contiguous()` is a no-op when it
    # already is, and prevents reading wrong page numbers from a strided view.
    block_tables_prefill = block_tables_prefill.contiguous()
    batch_size, num_of_page_per_token = block_tables_prefill.shape
    s = kv_cache.shape
    assert s[1] == 2
    assert dequant_dtype in (torch.bfloat16, torch.float16)
    k_cache_stride = s[2] * s[3] * s[4]
    kv_cache_stride = k_cache_stride * s[1]
    new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
    # mock kv cache contains just the pages needed by this prefill
    mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device)
    # we simply sequentially index the pages needed by this prefill
    mock_block_table = torch.arange(
        start=1,
        end=batch_size * num_of_page_per_token + 1,
        dtype=torch.int32,
        device=block_tables_prefill.device,
    ).reshape(batch_size, num_of_page_per_token)
    # One program instance per block-table entry.
    grid = (batch_size, num_of_page_per_token)
    _trtllm_prefill_attn_kvfp8_dequant[grid](
        kv_cache,
        block_tables_prefill,
        num_of_page_per_token,
        mock_kv_cache,
        k_scale,
        v_scale,
        k_cache_stride,
        kv_cache_stride,
    )
    return mock_kv_cache, mock_block_table

165

166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
class BatchDCPPrefillWrapper:
    """Prefill wrapper that combines two FlashInfer prefill passes.

    ``_context`` is a paged-KV prefill wrapper planned as non-causal, and
    ``_new_tokens`` is a ragged-KV prefill wrapper planned as causal. ``run``
    executes both and merges their outputs using their log-sum-exp values.
    """

    def __init__(
        self,
        workspace_buffer: torch.Tensor | None = None,
    ):
        """Create both underlying wrappers on the same workspace buffer.

        Args:
            workspace_buffer: Buffer handed to both FlashInfer wrappers.
        """
        self._context = BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer, get_kv_cache_layout()
        )
        self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(
            workspace_buffer, get_kv_cache_layout()
        )

    def plan(
        self,
        qo_indptr_cpu: torch.Tensor,
        paged_kv_indptr_cpu: torch.Tensor,
        paged_kv_indices: torch.Tensor,
        paged_kv_last_page_len_cpu: torch.Tensor,
        prefill_start: int,
        page_size: int,
        num_qo_heads: int,
        dcp_world_size: int,
        num_kv_heads: int,
        head_dim: int,
        sm_scale: float,
        window_left: int,
        logits_soft_cap: float | None,
        q_data_type: torch.dtype,
        kv_cache_dtype: torch.dtype,
        prefill_fixed_split_size: int,
        disable_split_kv: bool,
    ):
        """Plan the prefill operation with given parameters.

        The context pass is planned with ``num_qo_heads * dcp_world_size``
        query heads and only the entries of ``paged_kv_last_page_len_cpu``
        from ``prefill_start`` onwards. The new-tokens pass reuses
        ``qo_indptr_cpu`` as its KV indptr.
        """
        self._context.plan(
            qo_indptr_cpu,
            paged_kv_indptr_cpu,
            paged_kv_indices,
            paged_kv_last_page_len_cpu[prefill_start:],
            num_qo_heads * dcp_world_size,
            num_kv_heads,
            head_dim,
            page_size,
            causal=False,  # This is context run
            sm_scale=sm_scale,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
            kv_data_type=kv_cache_dtype,
            fixed_split_size=prefill_fixed_split_size,
            disable_split_kv=disable_split_kv,
        )
        self._new_tokens.plan(
            qo_indptr=qo_indptr_cpu,
            kv_indptr=qo_indptr_cpu,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim,
            head_dim_vo=head_dim,
            causal=True,  # This is newtokens run
            sm_scale=sm_scale,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
        )

    def run(
        self,
        layer: torch.nn.Module,
        prefill_query: torch.Tensor,
        kv_cache_permute: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
    ):
        """Run both passes and merge their results into ``out``.

        Args:
            layer: Module providing ``_k_scale_float`` and ``_v_scale_float``.
            prefill_query: Query tensor for the prefill tokens.
            kv_cache_permute: KV cache passed to the paged (context) pass.
            key: Keys passed to the ragged (new-tokens) pass.
            value: Values passed to the ragged (new-tokens) pass.
            out: Destination handed to ``merge_attn_states``.

        Returns:
            ``out``.
        """
        # Gather the queries of the whole DCP group along dim 1 before the
        # context pass.
        prefill_query_across_dcp = get_dcp_group().all_gather(
            prefill_query.contiguous(), dim=1
        )
        output_context_tmp, lse_context_tmp = self._context.run(
            prefill_query_across_dcp,
            kv_cache_permute,
            k_scale=layer._k_scale_float,
            v_scale=layer._v_scale_float,
            return_lse=True,
        )
        output_context, lse_context = cp_lse_ag_out_rs(
            output_context_tmp,
            lse_context_tmp,
            get_dcp_group(),
            return_lse=True,
            is_lse_base_on_e=False,
        )
        # Both LSE tensors are transposed (dims 0 and 1) and made contiguous
        # before being handed to merge_attn_states.
        lse_context = lse_context.transpose(0, 1).contiguous()

        output_query, lse_query = self._new_tokens.run(
            prefill_query,
            key,
            value,
            return_lse=True,
        )
        lse_query = lse_query.transpose(0, 1).contiguous()

        merge_attn_states(
            out,
            output_context,
            lse_context,
            output_query,
            lse_query,
        )
        return out


277
class FlashInferBackend(AttentionBackend):
    """Attention backend descriptor for FlashInfer.

    Exposes the static capabilities of the backend (supported dtypes, block
    sizes, head sizes, KV-cache shape/layout) and the implementation and
    metadata-builder classes that go with it.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
        "fp8_e5m2",
    ]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the page sizes this backend accepts."""
        # Note: Not sure for all platforms, but on Blackwell,
        # only support a page size of 16, 32, 64.
        return [16, 32, 64]

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> type["FlashInferImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashInferImpl

    @staticmethod
    def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        The shape is ``(num_blocks, 2, block_size, num_kv_heads, head_size)``;
        ``cache_dtype_str`` is accepted but not used here.
        """
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the dimension permutation for the configured cache layout.

        Args:
            include_num_layers_dimension: When true, return a 6-element
                permutation that also covers a ``num_layers`` dimension;
                otherwise a 5-element permutation.

        Raises:
            ValueError: If ``get_kv_cache_layout()`` is neither ``"NHD"`` nor
                ``"HND"``.
        """
        # `stride_order` indicates the permutation that gets us from
        # `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            if include_num_layers_dimension:
                # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
                return (1, 0, 2, 3, 4, 5)
            return (0, 1, 2, 3, 4)
        if cache_layout == "HND":
            if include_num_layers_dimension:
                # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
                return (1, 2, 4, 0, 3, 5)
            return (0, 1, 3, 2, 4)
        raise ValueError(f"Unknown cache layout format {cache_layout}.")

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Map an fp8 KV-cache dtype string to the matching torch dtype.

        Raises:
            ValueError: If ``kv_cache_dtype`` is not ``"fp8"``,
                ``"fp8_e4m3"`` or ``"fp8_e5m2"``.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        return [64, 128, 256]

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return whether ``capability`` lies in the inclusive range 7.5-12.1."""
        return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability(
            12, 1
        )

    @classmethod
    def supports_sink(cls) -> bool:
        """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
        from vllm.utils.flashinfer import (
            force_use_trtllm_attention,
            supports_trtllm_attention,
        )

        # Respect explicit disable flag (e.g.,
        # --attention-config.use_trtllm_attention=0)
        if force_use_trtllm_attention() is False:
            return False

        # Check if TRTLLM is supported on this platform
        return supports_trtllm_attention()

    @classmethod
    def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
        """Return ``"HND"`` when the device capability major is 10, else ``None``."""
        from vllm.platforms import current_platform

        capability = current_platform.get_device_capability()
        if capability is not None and capability.major == 10:
            return "HND"
        return None

381
382
383
384
385
386
387
388
389
390

@dataclass
class FlashInferMetadata:
    # Per-batch attention metadata produced by the FlashInfer metadata builder.

    num_actual_tokens: int  # Number of tokens excluding padding.

    # The data type of the query
    q_data_type: torch.dtype

    slot_mapping: torch.Tensor

    # For flashinfer trtllm batch decode
    max_q_len: int
    max_q_len_prefill: int
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table_tensor: torch.Tensor
    # Whether the prefill / decode portion of the batch uses TRTLLM attention.
    prefill_use_trtllm: bool
    decode_use_trtllm: bool

    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # For cascade attention (CPU for planning).
    use_cascade: bool

    # FlashInfer wrappers; each defaults to None when not set by the builder.
    prefill_wrapper: (
        BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
    ) = None
    decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
    cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None

    # Optional index-pointer tensors (GPU copies, going by the field names).
    qo_indptr_gpu: torch.Tensor | None = None
    paged_kv_indptr_gpu: torch.Tensor | None = None
417

418

419
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
420
    reorder_batch_threshold: int = 1
421

422
423
424
425
426
427
428
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up configuration, lazily-created wrappers and persistent buffers.

        Args:
            kv_cache_spec: Provides block size, number of KV heads, head size
                and KV-cache dtype.
            layer_names: Attention layers whose hyperparameters are inspected
                to derive the global ones.
            vllm_config: Source of the cache, model, compilation, scheduler,
                speculative, parallel and attention configuration.
            device: Device on which the device-side buffers are allocated.

        Raises:
            NotImplementedError: If the layers use attention sinks but TRTLLM
                attention cannot be used.
            AssertionError: If a non-fp8 KV-cache dtype differs from the model
                dtype, or for head size 256 with block size 16 on device
                capability 100.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.cache_config = vllm_config.cache_config
        self.model_config = vllm_config.model_config
        # Workspace buffer and wrappers are created on first use.
        self._workspace_buffer = None
        self._prefill_wrapper: (
            BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
        ) = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)

        # Batch-invariant mode pins the split sizes and disables split-KV.
        if vllm_is_batch_invariant():
            self.decode_fixed_split_size = 2048
            self.prefill_fixed_split_size = 4096
            self.disable_split_kv = True
        else:
            self.decode_fixed_split_size = -1
            self.prefill_fixed_split_size = -1
            self.disable_split_kv = False

        self.compilation_config = vllm_config.compilation_config
        # Upper bounds used to size the persistent page buffers below.
        max_num_pages_per_req = cdiv(
            self.model_config.max_model_len, self.kv_cache_spec.block_size
        )
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        max_num_pages = max_num_reqs * max_num_pages_per_req
        speculative_config = vllm_config.speculative_config
        num_spec_tokens = (
            speculative_config.num_speculative_tokens
            if speculative_config is not None
            else 0
        )
        self.enable_cuda_graph = (
            self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
        )
        if self.enable_cuda_graph:
            # For full cudagraph capture, one `decode_wrapper` for each batch
            # size is needed for FlashInfer.
            self._decode_wrappers_cudagraph: dict[
                int, BatchDecodeWithPagedKVCacheWrapper
            ] = {}
            self._decode_cudagraph_max_bs = min(
                (1 + num_spec_tokens) * max_num_reqs,
                self.compilation_config.max_cudagraph_capture_size,
            )

        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
            self.dcp_kv_cache_interleave_size = (
                vllm_config.parallel_config.dcp_kv_cache_interleave_size
            )
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
            self.dcp_kv_cache_interleave_size = 1

        self.num_qo_heads = self.model_config.get_num_attention_heads(
            self.vllm_config.parallel_config
        )

        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
        self.head_dim = self.kv_cache_spec.head_size
        self.page_size = self.kv_cache_spec.block_size

        self.cache_dtype = self.cache_config.cache_dtype
        if self.cache_dtype.startswith("fp8"):
            self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.cache_dtype
            )
        else:
            assert self.kv_cache_spec.dtype == self.model_config.dtype
            self.kv_cache_dtype = self.kv_cache_spec.dtype

        # Use model dtype as q dtype when TRTLLM attn is not supported, or
        # --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
        # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
        # if TRTLLM attention kernel is not used when building attn metadata
        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
        if (
            can_use_trtllm
            and not vllm_config.attention_config.disable_flashinfer_q_quantization
        ):
            self.q_data_type = self.kv_cache_dtype
        else:
            self.q_data_type = self.model_config.dtype

        # Prefer TRTLLM attention for decoding in all cases.
        # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
        self.use_trtllm_decode_attention = can_use_trtllm
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)

        self._cascade_wrapper = None  # Wrapper for cascade attention

        # Global hyperparameters shared by all attention layers
        # TODO: discard this for trtllm-gen backend
        self.global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)
        )
        self.sm_scale = self.global_hyperparameters.sm_scale
        self.window_left = self.global_hyperparameters.window_left
        self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
        self.has_sinks = self.global_hyperparameters.has_sinks
        if self.has_sinks and not can_use_trtllm:
            raise NotImplementedError(
                "FlashInfer backend currently does not support attention "
                "sinks, please use trtllm on blackwell or flash attention on "
                "earlier GPUs."
            )
        # Preparing persistent buffers (device-side)
        self.paged_kv_indptr = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device=self.device
        )
        self.paged_kv_indices = torch.zeros(
            max_num_pages,  # max num pages possible
            dtype=torch.int32,
            device=self.device,
        )
        self.paged_kv_last_page_len = torch.zeros(
            max_num_reqs, dtype=torch.int32, device=self.device
        )
        # host-side buffer
        pin_memory = is_pin_memory_available()
        self.paged_kv_indptr_cpu = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        # NumPy views share storage with the CPU tensors they are taken from.
        self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
        self.paged_kv_indptr_buffer = torch.zeros_like(
            self.paged_kv_indptr_cpu, pin_memory=pin_memory
        )
        self.paged_kv_indices_cpu = torch.zeros(
            max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.paged_kv_last_page_len_cpu = torch.zeros(
            max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()

        if self.head_dim == 256 and current_platform.is_device_capability(100):
            # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
            # head size 256 and block size 16 is not supported on blackwell.
            assert kv_cache_spec.block_size != 16, (
                "There is a bug in FlashInfer "
                "block_size 16 head size 256 support. Please avoid this combination by "
                "passing --block-size 32 or --block-size 64."
            )

575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
    @classmethod
    def get_cudagraph_support(
        cls: type["FlashInferMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        """Report the CUDA-graph support level for the given configuration.

        Returns ``AttentionCGSupport.UNIFORM_BATCH`` when
        ``can_use_trtllm_attention`` accepts the model's query/KV head counts,
        and ``AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE`` otherwise.
        """
        num_qo_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )
        if can_use_trtllm_attention(
            num_qo_heads=num_qo_heads,
            num_kv_heads=kv_cache_spec.num_kv_heads,
        ):
            return AttentionCGSupport.UNIFORM_BATCH
        return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

592
593
    def _get_workspace_buffer(self):
        """Return the workspace buffer, allocating it on first use.

        The buffer is a zero-filled ``torch.uint8`` tensor on ``self.device``.
        Its size is ``envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE``, or
        ``FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT`` when batch
        invariance is enabled.
        """
        if self._workspace_buffer is None:
            buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
            if vllm_is_batch_invariant():
                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
            self._workspace_buffer = torch.zeros(
                buffer_size, dtype=torch.uint8, device=self.device
            )
        return self._workspace_buffer

Woosuk Kwon's avatar
Woosuk Kwon committed
602
603
604
    def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
        """Replace this builder's workspace buffer with ``workspace_buffer``.

        Only the stored reference is updated; wrappers that were already
        created are not touched by this method.
        """
        self._workspace_buffer = workspace_buffer

605
606
607
    def _get_prefill_wrapper(
        self,
    ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
        """Return the prefill wrapper, creating and caching it on first use.

        A ``BatchDCPPrefillWrapper`` is created when ``self.dcp_world_size``
        is greater than 1; otherwise a ``BatchPrefillWithPagedKVCacheWrapper``
        with the configured KV-cache layout. Both use the builder's workspace
        buffer.
        """
        if self._prefill_wrapper is None:
            if self.dcp_world_size > 1:
                self._prefill_wrapper = BatchDCPPrefillWrapper(
                    workspace_buffer=self._get_workspace_buffer(),
                )
            else:
                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                    self._get_workspace_buffer(), get_kv_cache_layout()
                )
        assert self._prefill_wrapper is not None
        return self._prefill_wrapper

620
    def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
        """Return a decode wrapper, creating and caching it on first use.

        Args:
            batch_size: Batch size the wrapper is for. Only used when
                ``use_cudagraph`` is true, as the cache key and to slice the
                persistent buffers.
            use_cudagraph: When true, use the per-batch-size wrapper cache
                and bind the wrapper to slices of the persistent device
                buffers; when false, use the single general-shape wrapper
                with no pre-bound buffers.

        Returns:
            The cached or newly created
            ``BatchDecodeWithPagedKVCacheWrapper``.
        """
        if use_cudagraph:
            decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None)
        else:
            decode_wrapper = self._decode_wrapper

        if decode_wrapper is None:
            if use_cudagraph:
                paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1]
                paged_kv_indices = self.paged_kv_indices
                paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size]
            else:
                paged_kv_indptr = None
                paged_kv_indices = None
                paged_kv_last_page_len = None
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                get_kv_cache_layout(),
                use_cuda_graph=use_cudagraph,
                paged_kv_indptr_buffer=paged_kv_indptr,
                paged_kv_indices_buffer=paged_kv_indices,
                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                # Tensor cores are enabled by default because the perf would be
                # at least as good as cuda cores for all attention ops in latest
                # gpus.
                use_tensor_cores=True,
            )

            # save the decode wrapper
            if use_cudagraph:
                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
            else:
                self._decode_wrapper = decode_wrapper

        return decode_wrapper
655
656
657
658

    def _get_cascade_wrapper(self):
        """Return the cascade-attention wrapper, creating it on first use.

        The wrapper is a ``MultiLevelCascadeAttentionWrapper`` constructed
        with 2 as its first argument, the builder's workspace buffer and the
        configured KV-cache layout.
        """
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
                2, self._get_workspace_buffer(), get_kv_cache_layout()
            )
        return self._cascade_wrapper

663
664
665
666
667
668
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashInferMetadata:
669
        num_reqs = common_attn_metadata.num_reqs
670
        num_actual_tokens = common_attn_metadata.num_actual_tokens
671
672
673
674
675
676
677
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True,
            )
        )
678

679
        page_size = self.page_size
680
        max_q_len = common_attn_metadata.max_query_len
681
        max_seq_len = common_attn_metadata.max_seq_len
682
        seq_lens = common_attn_metadata.seq_lens
683
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
684
        block_table_tensor = common_attn_metadata.block_table_tensor
685
        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
686

687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
        if self.dcp_world_size > 1:
            if num_prefills > 0:
                qo_indptr_prefill_cpu = (
                    qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
                )
                query_lens_prefill_cpu = (
                    qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
                )
                seq_lens_cpu[num_decodes:] = (
                    seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
                )

            seq_lens_cpu = get_dcp_local_seq_lens(
                seq_lens_cpu,
                self.dcp_world_size,
                self.dcp_rank,
                self.dcp_kv_cache_interleave_size,
            )

        seq_lens_np = seq_lens_cpu.numpy()
707
        num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
708
709
710
711
712
713

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # Grab the blocks of the shared prefix from the first request.
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size
714
715

            # Create CPU versions directly for cascade (no GPU versions needed)
716
717
718
719
720
721
722
723
724
725
            shared_qo_indptr_cpu = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device="cpu"
            )
            shared_kv_page_indptr_cpu = torch.tensor(
                [0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
            )
            shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
            shared_kv_last_page_len_cpu = torch.tensor(
                [page_size], dtype=torch.int32, device="cpu"
            )
726

727
            # Remove the blocks of the shared prefix from all requests.
728
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
729
            num_blocks_np -= num_common_kv_blocks
730
        else:
731
732
733
734
735
            shared_qo_indptr_cpu = None
            shared_kv_page_indptr_cpu = None
            shared_kv_page_indices_cpu = None
            shared_kv_last_page_len_cpu = None

736
737
738
739
        # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
        np.cumsum(
            num_blocks_np,
            dtype=np.int32,
740
            out=self.paged_kv_indptr_np[1 : num_reqs + 1],
741
        )
742
743
744
        # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
        # after this line (e.g., for cuda graphs), we need to copy the data to
        # self.paged_kv_indptr_buffer to avoid race condition.
745
746
747
748
749
750
751
        self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
            : num_reqs + 1
        ]
        paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
        paged_kv_indptr.copy_(
            self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
        )
752

753
        # write self.paged_kv_indices inplace
754
        num_actual_pages = self.paged_kv_indptr_np[num_reqs]
755
        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
756
        _copy_page_indices_kernel[(num_reqs,)](
757
758
759
760
761
762
            paged_kv_indices,
            block_table_tensor,
            block_table_tensor.stride(0),
            paged_kv_indptr,
            BLOCK_SIZE=1024,
        )
763

764
        # write self.paged_kv_last_page_len_cpu inplace
765
766
        paged_kv_last_page_len_np = seq_lens_np % page_size
        self.paged_kv_last_page_len_np[:num_reqs] = np.where(
767
            (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
768
769
770
            page_size,
            paged_kv_last_page_len_np,
        )
771

772
        uses_spec_reorder = self.reorder_batch_threshold > 1
773
774
775
776
777
        prefill_use_trtllm = use_trtllm_attention(
            self.num_qo_heads,
            self.num_kv_heads,
            num_prefill_tokens,
            max_seq_len,
778
            self.dcp_world_size,
779
780
781
782
783
784
            self.cache_dtype,
            self.q_data_type,
            is_prefill=True,
            has_sinks=self.has_sinks,
            has_spec=uses_spec_reorder,
        )
785
786
787
        decode_use_trtllm = (
            self.use_trtllm_decode_attention and self.dcp_world_size <= 1
        )
788
789

        if not (prefill_use_trtllm and decode_use_trtllm):
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
            if self.has_sinks:
                raise NotImplementedError(
                    "FlashInfer backend currently does not support attention "
                    "sinks, please use trtllm on blackwell or flash attention "
                    "on earlier GPUs."
                )

            if not self.global_hyperparameters.has_same_window_lefts:
                raise ValueError(
                    "Window left is not the same for all layers. "
                    "One potential fix is to set disable_sliding_window=True"
                )

            assert self.global_hyperparameters.has_same_all_params, (
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values for the following "
                "hyperparameters: `window_left`, `logits_soft_cap`, "
                "`sm_scale`."
            )

            # The q quantization is not supported for non-trtllm attention,
            # fall back to model dtype.
812
813
            self.q_data_type = self.model_config.dtype

814
815
        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
816
            q_data_type=self.q_data_type,
817
            slot_mapping=common_attn_metadata.slot_mapping,
818
            max_q_len=max_q_len,
819
            max_q_len_prefill=max_q_len,
820
821
822
823
824
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table_tensor=block_table_tensor,
            prefill_use_trtllm=prefill_use_trtllm,
            decode_use_trtllm=decode_use_trtllm,
825
826
827
828
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
829
830
831
            use_cascade=use_cascade,
        )

832
        paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
833
        paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
834

835
836
837
838
839
840
841
842
843
844
845
846
        if attn_metadata.use_cascade:
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [shared_qo_indptr_cpu, qo_indptr_cpu],
                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                [shared_kv_page_indices_cpu, paged_kv_indices],
                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                causal=True,
847
848
849
                sm_scale=self.sm_scale,
                window_left=self.window_left,
                logits_soft_cap=self.logits_soft_cap,
850
851
852
853
854
                q_data_type=self.q_data_type,
                kv_data_type=self.kv_cache_dtype,
            )
        else:
            # Regular attention (common case).
855
            # Decodes are at the front and prefills are at the back.
856
857
858
859
860
861
            num_prefills = attn_metadata.num_prefills
            num_decodes = attn_metadata.num_decodes
            if num_prefills > 0:
                # Decodes are first so prefills start after the last decode
                prefill_start = num_decodes
                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
862
863
864
865
866
                assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
                assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
                assert (
                    paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills
                )
867
868
869
                # Since prefill_wrapper.run() will be called with
                # query[num_decode_tokens:] we need to adjust the qo_indptr
                # to be relative to the start of the prefill queries.
870
871
872
                qo_indptr_cpu = (
                    qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
                )
873
                paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
874
875
876
877
878
879

                # Recompute max_q_len for the slice of requests we are using
                # for prefills. This can be different from max_q_len when
                # we have a non-uniform batch with some short decodes offloaded
                # to the prefill pathway
                query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
880
                attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())
881

882
                if not attn_metadata.prefill_use_trtllm:
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
                    if self.dcp_world_size > 1:
                        assert isinstance(
                            attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
                        )
                        attn_metadata.prefill_wrapper.plan(
                            qo_indptr_cpu=qo_indptr_cpu,
                            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
                            paged_kv_indices=paged_kv_indices,
                            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
                            prefill_start=prefill_start,
                            page_size=self.page_size,
                            num_qo_heads=self.num_qo_heads,
                            dcp_world_size=self.dcp_world_size,
                            num_kv_heads=self.num_kv_heads,
                            head_dim=self.head_dim,
                            sm_scale=self.sm_scale,
                            window_left=self.window_left,
                            logits_soft_cap=self.logits_soft_cap,
                            q_data_type=self.q_data_type,
                            kv_cache_dtype=self.kv_cache_dtype,
                            prefill_fixed_split_size=self.prefill_fixed_split_size,
                            disable_split_kv=self.disable_split_kv,
                        )
                    else:
                        assert isinstance(
                            attn_metadata.prefill_wrapper,
                            BatchPrefillWithPagedKVCacheWrapper,
                        )
                        attn_metadata.prefill_wrapper.plan(
                            qo_indptr_cpu,
                            paged_kv_indptr_cpu,
                            paged_kv_indices,
                            paged_kv_last_page_len_cpu[prefill_start:],
                            self.num_qo_heads,
                            self.num_kv_heads,
                            self.head_dim,
                            self.page_size,
                            causal=True,
                            sm_scale=self.sm_scale,
                            window_left=self.window_left,
                            logits_soft_cap=self.logits_soft_cap,
                            q_data_type=self.q_data_type,
                            kv_data_type=self.kv_cache_dtype,
                            fixed_split_size=self.prefill_fixed_split_size,
                            disable_split_kv=self.disable_split_kv,
                        )
929
                else:
930
                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
931
932
                        self.device, non_blocking=True
                    )
933
                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
934
935
                        self.device, non_blocking=True
                    )
936
937
938

            if num_decodes > 0:
                pure_decode = num_prefills == 0
939
940
941
                use_cudagraph = (
                    self.enable_cuda_graph
                    and pure_decode
942
                    and num_decode_tokens <= self._decode_cudagraph_max_bs
943
                )
944
                num_input_tokens = num_decode_tokens
945
946

                attn_metadata.decode_wrapper = self._get_decode_wrapper(
947
948
                    num_input_tokens, use_cudagraph
                )
949
950
951
952
953
954
                if not attn_metadata.decode_use_trtllm:
                    # Use the persistent buffer with padding length,
                    # instead of the same address but chunked version
                    # in attn_metadata when using cudagraph.
                    fast_plan_decode(
                        attn_metadata.decode_wrapper,
955
                        self.paged_kv_indptr_cpu[: num_input_tokens + 1],
956
957
958
                        paged_kv_indices,
                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
                        seq_lens_cpu[:num_input_tokens],
959
                        self.num_qo_heads * self.dcp_world_size,
960
961
962
963
964
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        # Disable flashinfer's pos encoding and use vllm's rope.
                        pos_encoding_mode="NONE",
965
966
967
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
968
969
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
970
971
                        fixed_split_size=self.decode_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
972
                    )
973
974
975
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Report whether cascade attention should be used for this batch.

        Cascade attention is currently disabled, so this always returns
        ``False``. The positional and keyword arguments are accepted only to
        keep the call signature; they are not inspected.

        Returns:
            Always ``False``.
        """
        if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        # TODO: Cascade attention doesn't work, disable it for now
        # return use_cascade_attention(*args, **kwargs)
        return False
983
984
985


class FlashInferImpl(AttentionImpl):
986
987
    can_return_lse_for_decode: bool = True

988
989
990
991
992
993
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Set up the per-layer state used by the FlashInfer attention impl.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale; stored as a Python ``float``.
            num_kv_heads: Number of KV heads. ``num_heads // num_kv_heads``
                is stored as ``num_queries_per_kv``.
            alibi_slopes: Optional ALiBi slopes; when given they are stored
                as a float32 tensor.
            sliding_window: Optional sliding-window size. ``None`` is stored
                as ``(-1, -1)``; otherwise as ``(sliding_window - 1, 0)``.
                The first element becomes ``window_left``.
            kv_cache_dtype: KV cache dtype string (e.g. one starting with
                ``"fp8"``).
            logits_soft_cap: Optional logits soft cap.
            attn_type: Attention type; only ``AttentionType.DECODER`` is
                accepted.
            kv_sharing_target_layer_name: When not ``None``, ``forward``
                skips writing key/value into the KV cache for this layer.
            sinks: Optional attention-sink tensor whose first dimension must
                equal ``num_heads``.

        Raises:
            NotImplementedError: If ``attn_type`` is not
                ``AttentionType.DECODER``.
            ValueError: If ``sinks`` does not have ``num_heads`` entries
                along its first dimension.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        # self.sliding_window is always a tuple at this point, so its first
        # element is the left window size (-1 when there is no window).
        self.window_left = self.sliding_window[0]
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferImpl"
            )

        self.sinks: torch.Tensor | None = None
        if sinks is not None:
            if sinks.shape[0] != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}."
                )
            self.sinks = sinks

        self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
        vllm_config = get_current_vllm_config()
        # Quantized query input is only offered when the TRT-LLM kernels are
        # usable and the user has not disabled q quantization.
        self.supports_quant_query_input = (
            self.support_trtllm_attn
            and not vllm_config.attention_config.disable_flashinfer_q_quantization
        )
        # Host-side scales for the TRT-LLM kernels; filled in lazily by
        # forward() on first use.
        self.bmm1_scale: float | None = None
        self.bmm2_scale: float | None = None
        self.o_sf_scale: float | None = None
1049

1050
    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """Return whether output quantization can be fused into attention.

        Fusion is reported as supported only when all of the following hold:
        TRT-LLM attention is usable for this layer, the KV cache dtype string
        starts with ``"fp8"``, and ``quant_key`` is either
        ``kFp8StaticTensorSym`` or ``kNvfp4Quant``.

        Args:
            quant_key: The quantization scheme requested for the output.

        Returns:
            ``True`` if the fused attention+quant path can be used.
        """
        return (
            self.support_trtllm_attn
            and self.kv_cache_dtype.startswith("fp8")
            and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
        )
1056

1057
1058
1059
1060
1061
    # FlashInfer requires attention sinks to be float32
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Cast the attention sinks to float32 if they are not already.

        Nothing happens when the layer has no sinks or the sinks are already
        float32. ``act_dtype`` is accepted but not used here.
        """
        current_sinks = self.sinks
        if current_sinks is None or current_sinks.dtype == torch.float32:
            return
        self.sinks = current_sinks.to(torch.float32)

1062
1063
1064
1065
1066
1067
1068
1069
    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Writes new key/value entries into the KV cache (unless this layer
        shares another layer's cache), then runs prefill and/or decode
        attention and writes the result into ``output`` in place.

        Args:
            layer: The attention layer. Its ``_q_scale_float``,
                ``_k_scale_float``, ``_v_scale_float``, ``_k_scale``,
                ``_v_scale`` and ``_o_scale_float`` attributes are read here;
                ``_o_scale_float`` is also set on first fused-quant use.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: KV cache tensor with different possible shapes:
                - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
                - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention. ``None`` indicates a
                profiling run, in which case ``output`` is zero-filled.
            output: Pre-allocated output tensor; must be provided.
            output_scale: Output quantization scale. When provided, the
                attn+quant fusion path is taken, which requires an FP8 query
                and TRT-LLM kernels for both prefill and decode.
            output_block_scale: Block scales for NVFP4 output. Required when
                ``output`` is FP4, and must be ``None`` otherwise.
        Returns:
            shape = [num_tokens, num_heads * head_size]
            The (possibly padded) ``output`` tensor that was passed in. On
            the cascade path the slice truncated to the actual token count is
            returned instead.
        Raises:
            ValueError: If fusion is requested and ``output`` is neither FP8
                nor FP4.
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # Ensure query dtype matches the expected dtype from attention metadata
        assert attn_metadata.q_data_type == query.dtype, (
            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
            f"got {query.dtype}"
        )

        # The bmm scales are computed once and cached on the impl.
        if self.bmm1_scale is None:
            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        # The attn+quant fusion happens when output_scale is provided.
        if output_scale is None:
            assert output_block_scale is None, (
                "output_block_scale is not supported when fusion has not happened"
            )
        else:
            assert attn_metadata.q_data_type == FP8_DTYPE, (
                "Query must be FP8 when attn+quant fusion happened."
            )
            assert (
                attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm
            ), "Must use TRT-LLM attn"

            if output.dtype == FP8_DTYPE:
                assert output_block_scale is None, (
                    "output_block_scale should not be provided for fp8 output"
                )
            elif output.dtype == FP4_DTYPE:
                assert output_block_scale is not None, (
                    "output_block_scale is required for nvfp4 output"
                )
            else:
                raise ValueError(f"Unsupported output dtype: {output.dtype}")

            # The TRTLLM attn kernel requires the output scale to be passed
            # as a host scalar, so store the o scale as a host scalar during
            # the warmup run, when cuda graph is not enabled.
            if layer._o_scale_float is None:
                layer._o_scale_float = output_scale.cpu().item()
                if output.dtype == FP8_DTYPE:
                    # Fold the output scale into bmm2 for FP8 output.
                    self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
                elif output.dtype == FP4_DTYPE:
                    self.o_sf_scale = layer._o_scale_float

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if self.kv_cache_dtype.startswith("fp8"):
                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    self.kv_cache_dtype
                )
                kv_cache = kv_cache.view(torch_dtype)

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
        key = key[:num_actual_tokens]
        value = value[:num_actual_tokens]
        # Keep a handle on the padded tensor; it is what gets returned on the
        # regular path.
        output_padded = output
        output = output[:num_actual_tokens]

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            return output

        # When using spec decoding, num_decodes can be < num_decode_tokens
        # because some decode requests may have more than one query token.
        num_decodes = attn_metadata.num_decodes
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        kv_cache_permute = kv_cache.permute(*stride_order)
        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back.
        if num_prefill_tokens > 0:
            prefill_wrapper = attn_metadata.prefill_wrapper
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
            assert prefill_wrapper is not None

            if not attn_metadata.prefill_use_trtllm:
                if self.dcp_world_size > 1:
                    # The hyperparameters planned into the wrapper must match
                    # this layer's.
                    assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
                    assert prefill_wrapper._context._window_left == self.window_left
                    assert prefill_wrapper._context._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._context._sm_scale == self.scale
                    assert not prefill_wrapper._context._causal
                    assert prefill_wrapper._new_tokens._window_left == self.window_left
                    assert prefill_wrapper._new_tokens._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._new_tokens._sm_scale == self.scale
                    assert prefill_wrapper._new_tokens._causal

                    prefill_wrapper.run(
                        layer,
                        prefill_query,
                        kv_cache_permute,
                        key[num_decode_tokens:],
                        value[num_decode_tokens:],
                        out=output[num_decode_tokens:],
                    )
                else:
                    assert isinstance(
                        prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper
                    )
                    assert prefill_wrapper._window_left == self.window_left
                    assert prefill_wrapper._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._sm_scale == self.scale
                    assert prefill_wrapper._causal
                    prefill_wrapper.run(
                        prefill_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[num_decode_tokens:],
                    )
            else:
                # prefill_query may be non-contiguous
                prefill_query = prefill_query.contiguous()
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                # Per-request tensors are indexed by request, so the prefill
                # rows start after the decode requests.
                block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:]
                seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert prefill_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_prefill.is_contiguous()
                assert seq_lens_prefill.is_contiguous()

                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
                    out = FP4Tensor(
                        data=output[num_decode_tokens:],
                        scale=output_block_scale,
                        scale_start_index=num_decode_tokens,
                        original_shape=prefill_query.shape,
                    )
                else:
                    assert self.o_sf_scale is None
                    out = output[num_decode_tokens:]

                if (
                    attn_metadata.q_data_type != FP8_DTYPE
                    and self.kv_cache_dtype.startswith("fp8")
                ):
                    # TRTLLM prefill attention does not support BF16 Q
                    # and fp8 kv cache. So to enable prefill attention
                    # with fp8 kv cache, we can construct a mock block
                    # and mock kv cache with BF16 KV involved in the prefill
                    mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
                        kv_cache_permute,
                        block_tables_prefill,
                        layer._k_scale,
                        layer._v_scale,
                        attn_metadata.q_data_type,
                    )
                else:
                    mock_kv_cache = kv_cache_permute
                    mock_block_table = block_tables_prefill

                trtllm_batch_context_with_kv_cache(
                    query=prefill_query,
                    kv_cache=mock_kv_cache,
                    workspace_buffer=workspace_buffer,
                    block_tables=mock_block_table,
                    seq_lens=seq_lens_prefill,
                    max_q_len=attn_metadata.max_q_len_prefill,
                    max_kv_len=attn_metadata.max_seq_len,
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    batch_size=attn_metadata.num_prefills,
                    cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                    cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
                    window_left=self.window_left,
                    sinks=self.sinks,
                    o_sf_scale=self.o_sf_scale,
                    out=out,
                )

        if num_decode_tokens > 0:
            decode_wrapper = attn_metadata.decode_wrapper
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            assert decode_wrapper is not None

            if not attn_metadata.decode_use_trtllm:
                assert decode_wrapper._window_left == self.window_left
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
                assert decode_wrapper._sm_scale == self.scale

                if self.dcp_world_size > 1:
                    # Gather the query heads across the DCP group, run decode
                    # with LSE, then combine the partial results.
                    decode_query = get_dcp_group().all_gather(
                        decode_query.contiguous(), dim=-2
                    )
                    output_tmp = torch.empty_like(decode_query)
                    lse = torch.empty(
                        (decode_query.size(0), decode_query.size(1)),
                        dtype=torch.float32,
                        device=decode_query.device,
                    )
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output_tmp,
                        lse=lse,
                        return_lse=True,
                    )
                    output[:num_decode_tokens] = cp_lse_ag_out_rs(
                        output_tmp,
                        lse,
                        get_dcp_group(),
                        is_lse_base_on_e=False,
                    )
                else:
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[:num_decode_tokens],
                    )
            else:
                # decode_query may be non-contiguous
                decode_query = decode_query.contiguous()
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                # NOTE(review): these are sliced by num_decode_tokens while
                # the prefill path slices the same per-request tensors by
                # num_decodes; the two differ under spec decoding. Confirm
                # this is intended before changing it.
                block_tables_decode = attn_metadata.block_table_tensor[
                    :num_decode_tokens
                ]
                seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert decode_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_decode.is_contiguous()
                assert seq_lens_decode.is_contiguous()

                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
                    out = FP4Tensor(
                        data=output[:num_decode_tokens],
                        scale=output_block_scale,
                        scale_start_index=0,
                        original_shape=decode_query.shape,
                    )
                else:
                    assert self.o_sf_scale is None
                    out = output[:num_decode_tokens]

                if num_decode_tokens % attn_metadata.num_decodes != 0:
                    # This gets triggered when the dummy_run forces
                    # attention to be initialized with q_len = 0
                    q_len_per_req = 1
                else:
                    q_len_per_req = num_decode_tokens // attn_metadata.num_decodes

                trtllm_batch_decode_with_kv_cache(
                    query=decode_query,
                    kv_cache=kv_cache_permute,
                    workspace_buffer=workspace_buffer,
                    block_tables=block_tables_decode,
                    seq_lens=seq_lens_decode,
                    max_seq_len=attn_metadata.max_seq_len,
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    window_left=self.window_left,
                    sinks=self.sinks,
                    o_sf_scale=self.o_sf_scale,
                    out=out,
                    q_len_per_req=q_len_per_req,
                )
        return output_padded
1408
1409
1410
1411
1412
1413
1414


def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
1415
    seq_lens_cpu: torch.Tensor,
1416
1417
1418
1419
1420
1421
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
1422
    logits_soft_cap: float | None = None,
1423
1424
1425
    q_data_type: str | torch.dtype | None = "float16",
    kv_data_type: str | torch.dtype | None = None,
    data_type: str | torch.dtype | None = None,
1426
1427
1428
    sm_scale: float | None = None,
    rope_scale: float | None = None,
    rope_theta: float | None = None,
1429
    non_blocking: bool = True,
1430
1431
    fixed_split_size: int = -1,
    disable_split_kv: bool = False,
1432
1433
) -> None:
    """
1434
1435
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
    cudagraph capture/replay, while the no cudagraph version turns back
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
    to the original plan.
    using original plan after passing host-side buffers:
    - only host-to-device copy of indptr and last_page_len buffers
    Modifications for cudagraph:
    - only host-to-device copy of indptr and last_page_len buffers.
    - avoid device-to-device copy of indices buffer.

    Part of the code get inspiration from the original plan from FlashInfer repo
    and the implementation of fast_decode_plan for FlashInfer in SGlang repo.
    """
    # Warm up with the original plan if it is first call, and always run the
    # original plan if we run for dynamic shape. For fixed shape (cudagraph),
    # this warm up is to generate the _cached_module for the decode wrapper.
1449
    if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
        self.plan(
            indptr_cpu,
            indices,
            last_page_len_cpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            pos_encoding_mode,
            window_left,
            logits_soft_cap,
            q_data_type,
            kv_data_type,
            data_type,
            sm_scale,
            rope_scale,
            rope_theta,
            non_blocking,
1468
1469
1470
1471
            None,  # block_tables
            None,  # seq_lens
            fixed_split_size,
            disable_split_kv,
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
        )
        self.vllm_first_call = False
        return

    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

    batch_size = len(last_page_len_cpu)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type
1493
1494
1495
1496
1497
1498
    q_data_type = (
        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
    )
    kv_data_type = (
        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
    )
1499
1500
1501
1502
1503

    if batch_size != self._fixed_batch_size:
        raise ValueError(
            "The batch size should be fixed in cudagraph mode, the runtime "
            "batch size {} mismatches the batch size set during "
1504
1505
            "initialization {}".format(batch_size, self._fixed_batch_size)
        )
1506
1507
    if len(indices) > len(self._paged_kv_indices_buf):
        raise ValueError(
1508
1509
            "The size of indices should be less than or equal to the allocated buffer"
        )
1510
1511
1512
1513

    # host-to-device copy for the indptr buffer
    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
    # host-to-device copy for the last_page_len buffer
1514
    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True)
1515

1516
1517
1518
    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

    try:
1519
        # Make sure we pass exactly 19 arguments for tensor core version
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
        self._plan_info = self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_host,
            indptr_cpu,
            seq_lens_cpu,
            batch_size,  # total_num_rows
            batch_size,
            num_qo_heads,
            num_kv_heads,
            page_size,
            self.is_cuda_graph_enabled,
            head_dim,
            head_dim,
            False,  # causal
1536
            window_left,
1537
1538
            fixed_split_size,
            disable_split_kv,
1539
            0,
1540
1541
1542
        )
    except Exception as e:
        raise RuntimeError(f"Error in tensor core plan: {e}") from e
1543
1544
1545
1546
1547
1548
1549

    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568


@triton.jit
def _copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    # The program id on axis 0 selects which block_table row to copy. The
    # row's entries are written into page_indices over the index range
    # [cu_num_blocks[pid], cu_num_blocks[pid + 1]).
    pid = tl.program_id(0)

    # Destination range for this row, read from two consecutive
    # cu_num_blocks entries.
    dst_begin = tl.load(cu_num_blocks + pid)
    dst_end = tl.load(cu_num_blocks + pid + 1)
    count = dst_end - dst_begin

    # Start of this row in block_table.
    src_row = block_table + pid * block_table_stride

    lanes = tl.arange(0, BLOCK_SIZE)
    # Copy BLOCK_SIZE entries per iteration; the mask guards the tail chunk.
    for base in tl.range(0, count, BLOCK_SIZE):
        idx = base + lanes
        in_bounds = idx < count
        vals = tl.load(src_row + idx, mask=in_bounds)
        tl.store(page_indices + dst_begin + idx, vals, mask=in_bounds)