flashinfer.py 42.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Attention layer with FlashInfer."""
from __future__ import annotations

from dataclasses import dataclass
7
from typing import ClassVar, Optional, Union
8
9
10
11
12

import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                        BatchPrefillWithPagedKVCacheWrapper,
                        MultiLevelCascadeAttentionWrapper)
13
14
from flashinfer.decode import (_get_range_buf, get_seq_lens,
                               trtllm_batch_decode_with_kv_cache)
15
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
16
17
18
19

import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionType)
20
from vllm.config import VllmConfig
21
from vllm.logger import init_logger
22
from vllm.utils import cdiv, is_pin_memory_available
23
from vllm.utils.flashinfer import use_trtllm_attention
24
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
25
26
27
28
29
30
31
32
33
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              get_kv_cache_layout,
                                              get_per_layer_parameters,
                                              infer_global_hyperparameters,
                                              split_decodes_and_prefills)
34
from vllm.v1.kv_cache_interface import AttentionSpec
35
36
37
38
39
40
41
42
43
44

# Size in bytes (256 MiB) of the uint8 scratch workspace that the metadata
# builder allocates once and shares between all of its FlashInfer wrappers.
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 2**20

logger = init_logger(__name__)


class FlashInferBackend(AttentionBackend):
    """vLLM V1 attention backend that dispatches to FlashInfer kernels.

    Exposes the static configuration (name, implementation/metadata/builder
    classes, supported dtypes and head sizes, KV-cache shape and layout)
    that vLLM queries when selecting and wiring up an attention backend.
    """

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes this backend reports as supported."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes this backend reports as supported."""
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        return [64, 128, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check that ``head_size`` is supported by this backend.

        Raises:
            ValueError: if ``head_size`` is not in ``get_supported_head_sizes``;
                the message suggests FlexAttention as a fallback.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "FLASHINFER_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type[FlashInferImpl]:
        """Return the attention implementation class."""
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> type[FlashInferMetadata]:
        """Return the per-batch metadata class."""
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> type[FlashInferMetadataBuilder]:
        """Return the metadata builder class."""
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        The second dimension of size 2 separates the K and V planes; the
        physical memory order is given by ``get_kv_cache_stride_order``.
        """
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation from the logical shape to memory layout.

        Raises:
            ValueError: if the configured KV-cache layout is neither
                ``"NHD"`` nor ``"HND"``.
        """
        # `stride_order` indicates the permutation that gets us from
        # `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            # Swap block_size and num_kv_heads so heads are outermost
            # within each K/V plane.
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Map a vLLM fp8 KV-cache dtype string to the torch dtype.

        ``"fp8"`` is treated as an alias of ``"fp8_e4m3"``.

        Raises:
            ValueError: for any other string.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")


@dataclass
class FlashInferMetadata:
    """Per-batch attention metadata consumed by the FlashInfer backend.

    Instances are produced by ``FlashInferMetadataBuilder.build``. The
    ``*_wrapper`` and ``*_gpu`` fields default to ``None`` and are filled in
    afterwards while the builder plans the FlashInfer wrappers.
    """

    num_actual_tokens: int  # Number of tokens excluding padding.

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    qo_indptr_cpu: torch.Tensor

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]

    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
    paged_kv_indptr_cpu: torch.Tensor
    # The page indices of the paged kv cache (on device for plan)
    paged_kv_indices: torch.Tensor
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size] (CPU for plan)
    paged_kv_last_page_len_cpu: torch.Tensor

    # The number of query/output heads
    num_qo_heads: int
    # The number of key/value heads
    num_kv_heads: int
    # The dimension of the attention heads
    head_dim: int
    # Block size of vllm
    page_size: int
    # The data type of the paged kv cache
    kv_data_type: torch.dtype
    # The data type of the query
    q_data_type: torch.dtype

    slot_mapping: torch.Tensor

    # For flashinfer trtllm batch decode
    max_q_len: int
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table_tensor: torch.Tensor
    # Whether the prefill / decode halves of the batch use the TRT-LLM
    # kernels instead of the FlashInfer wrappers.
    prefill_use_trtllm: bool
    decode_use_trtllm: bool

    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # For cascade attention (CPU for planning).
    use_cascade: bool
    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None

    # Planned wrappers, attached by the builder after construction.
    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
    cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

    # Device copies of the prefill indptrs; only set when prefill runs
    # through the TRT-LLM path.
    qo_indptr_gpu: Optional[torch.Tensor] = None
    paged_kv_indptr_gpu: Optional[torch.Tensor] = None

    def __post_init__(self):
        """Reject head sizes the FlashInfer backend does not support."""
        if self.head_dim is not None:
            FlashInferBackend.validate_head_size(self.head_dim)


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Build per-batch ``FlashInferMetadata`` and plan FlashInfer wrappers.

    The builder owns persistent device-side and host-side index buffers that
    every ``build`` call writes into in place, plus FlashInfer wrappers that
    are created lazily on first use and share a single workspace buffer.
    """

    # Full CUDA graphs are only supported for decode-only batches.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY

    # NOTE(review): consumed by the base class when reordering the batch so
    # decodes come first (see the comment in `_plan`); presumably requests
    # with query length <= 1 count as decodes — confirm.
    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        self.device = device
        self.vllm_config = vllm_config
        self.cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec

        # FlashInfer objects are created on first use by the _get_* helpers.
        self._workspace_buffer = None
        self._prefill_wrapper = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)
        self._cascade_wrapper = None  # Wrapper for cascade attention

        self.compilation_config = vllm_config.compilation_config
        max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
                                     self.kv_cache_spec.block_size)
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        max_num_pages = max_num_reqs * max_num_pages_per_req
        self.enable_cuda_graph = self.compilation_config.full_cuda_graph
        if self.enable_cuda_graph:
            # For full cudagraph capture, one `decode_wrapper` for each batch
            # size is needed for FlashInfer.
            self._decode_wrappers_cudagraph: dict[
                int, BatchDecodeWithPagedKVCacheWrapper] = {}
            self._decode_cudagraph_max_bs = min(
                max_num_reqs, self.compilation_config.max_capture_size)

        # Global hyperparameters shared by all attention layers
        self.global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))

        # Preparing persistent buffers (device-side), sized for the largest
        # possible batch so they can back cudagraph decode wrappers.
        self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                           dtype=torch.int32,
                                           device=self.device)
        self.paged_kv_indices = torch.zeros(
            max_num_pages,  # max num pages possible
            dtype=torch.int32,
            device=self.device)
        self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                                  dtype=torch.int32,
                                                  device=self.device)
        # host-side buffers (pinned when available); build() writes these in
        # place and _plan() hands slices of them to FlashInfer's plan().
        pin_memory = is_pin_memory_available()
        self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=pin_memory)
        self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
                                                dtype=torch.int32,
                                                device="cpu",
                                                pin_memory=pin_memory)
        self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
                                                      dtype=torch.int32,
                                                      device="cpu",
                                                      pin_memory=pin_memory)

        # Column indices [0, max_num_pages_per_req) used to mask block tables.
        self.block_table_arange = torch.arange(max_num_pages_per_req,
                                               dtype=torch.int32,
                                               device=self.device)

    def _get_workspace_buffer(self):
        """Return the shared FlashInfer workspace, allocating it once."""
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        """Return the cached prefill/append wrapper, creating it once."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), get_kv_cache_layout())
        return self._prefill_wrapper

    def _get_decode_wrapper(self,
                            batch_size: int,
                            use_cudagraph: bool = False):
        """Return a decode wrapper, creating and caching it on first use.

        With ``use_cudagraph`` one wrapper is kept per (padded) batch size and
        is bound to slices of the persistent device buffers; otherwise a
        single general-shape wrapper is shared.
        """
        if use_cudagraph:
            decode_wrapper = self._decode_wrappers_cudagraph.get(
                batch_size, None)
        else:
            decode_wrapper = self._decode_wrapper

        if decode_wrapper is None:
            num_qo_heads = (
                self.vllm_config.model_config.get_num_attention_heads(
                    self.vllm_config.parallel_config))
            num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
                self.vllm_config.parallel_config)
            # Tensor cores are used when forced via env or when the GQA group
            # size exceeds 4.
            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                num_qo_heads // num_kv_heads > 4)

            if use_cudagraph:
                paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
                paged_kv_indices = self.paged_kv_indices
                paged_kv_last_page_len = self.paged_kv_last_page_len[:
                                                                     batch_size]
            else:
                paged_kv_indptr = None
                paged_kv_indices = None
                paged_kv_last_page_len = None
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                get_kv_cache_layout(),
                use_cuda_graph=use_cudagraph,
                paged_kv_indptr_buffer=paged_kv_indptr,
                paged_kv_indices_buffer=paged_kv_indices,
                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                use_tensor_cores=use_tensor_cores)

            # save the decode wrapper
            if use_cudagraph:
                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
            else:
                self._decode_wrapper = decode_wrapper

        return decode_wrapper

    def _get_cascade_wrapper(self):
        """Return the cached two-level cascade wrapper, creating it once."""
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
                2, self._get_workspace_buffer(), get_kv_cache_layout())
        return self._cascade_wrapper

    def _plan(self, attn_metadata: FlashInferMetadata):
        """Plan the FlashInfer wrappers needed for ``attn_metadata``.

        Attaches the planned wrapper(s) to ``attn_metadata``. Cascade batches
        use a single cascade wrapper; otherwise the prefill and decode parts
        are planned separately. When a part uses the TRT-LLM path its
        wrapper is not planned (for prefill, device indptrs are stored
        instead).
        """
        if attn_metadata.use_cascade:
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [
                    attn_metadata.shared_qo_indptr_cpu,
                    attn_metadata.qo_indptr_cpu
                ],
                [
                    attn_metadata.shared_kv_page_indptr_cpu,
                    attn_metadata.paged_kv_indptr_cpu
                ],
                [
                    attn_metadata.shared_kv_page_indices_cpu,
                    attn_metadata.paged_kv_indices
                ],
                [
                    attn_metadata.shared_kv_last_page_len_cpu,
                    attn_metadata.paged_kv_last_page_len_cpu
                ],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                causal=True,
                sm_scale=self.global_hyperparameters.sm_scale,
                window_left=self.global_hyperparameters.window_left,
                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.kv_data_type,
            )
            return

        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
        num_prefills = attn_metadata.num_prefills
        num_decodes = attn_metadata.num_decodes
        if num_prefills > 0:
            # Decodes are first so prefills start after the last decode
            prefill_start = num_decodes
            attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
            assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
                0] == num_prefills + 1
            assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
                0] == num_prefills + 1
            assert attn_metadata.paged_kv_last_page_len_cpu[
                prefill_start:].shape[0] == num_prefills
            # Since prefill_wrapper.run() will be called with
            # query[num_decode_tokens:] we need to adjust the qo_indptr
            # to be relative to the start of the prefill queries.
            qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
                prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
            paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
                prefill_start:]
            if not attn_metadata.prefill_use_trtllm:
                attn_metadata.prefill_wrapper.plan(
                    qo_indptr_cpu,
                    paged_kv_indptr_cpu,
                    attn_metadata.paged_kv_indices,
                    attn_metadata.
                    paged_kv_last_page_len_cpu[prefill_start:],
                    attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads,
                    attn_metadata.head_dim,
                    attn_metadata.page_size,
                    causal=True,
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=attn_metadata.q_data_type,
                    kv_data_type=attn_metadata.kv_data_type,
                )
            else:
                attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
                attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
                    self.device)

        if num_decodes > 0:
            pure_decode = num_prefills == 0
            # possible required padding for cudagraph replay
            use_cudagraph = (self.enable_cuda_graph and pure_decode and
                             num_decodes <= self._decode_cudagraph_max_bs)
            if use_cudagraph:
                num_input_tokens = (
                    self.vllm_config.pad_for_cudagraph(num_decodes))
                # Carefully fulfill the padding region with reasonable value
                # on cpu.
                # Make sure paged_kv_indptr_cpu is not decreasing
                self.paged_kv_indptr_cpu[1 + num_decodes:1 +
                                         num_input_tokens].fill_(
                                             attn_metadata.
                                             paged_kv_indptr_cpu[-1])
                # Fill the remaining paged_kv_last_page_len_cpu with 1.
                # This is because flashinfer treats 0 as a full page
                # instead of empty.
                self.paged_kv_last_page_len_cpu[
                    num_decodes:num_input_tokens].fill_(1)
            else:
                num_input_tokens = num_decodes

            attn_metadata.decode_wrapper = self._get_decode_wrapper(
                num_input_tokens, use_cudagraph)
            if not attn_metadata.decode_use_trtllm:
                # Use the persistent buffer with padding length,
                # instead of the same address but chunked version
                # in atten_metadata when using cudagraph.
                fast_plan_decode(
                    attn_metadata.decode_wrapper,
                    self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                    attn_metadata.paged_kv_indices,
                    self.paged_kv_last_page_len_cpu[:num_input_tokens],
                    attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads,
                    attn_metadata.head_dim,
                    attn_metadata.page_size,
                    # Disable flashinfer's pos encoding and use vllm's rope.
                    pos_encoding_mode="NONE",
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=attn_metadata.q_data_type,
                    kv_data_type=attn_metadata.kv_data_type,
                )

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashInferMetadata:
        """Build and plan the FlashInfer metadata for one batch.

        Args:
            common_prefix_len: length (in tokens) of the prefix shared by all
                requests; > 0 enables cascade attention and must be a
                multiple of the page size.
            common_attn_metadata: backend-agnostic batch description.
            fast_build: unused by this builder.

        Returns:
            The planned ``FlashInferMetadata``. Its index tensors are views
            into this builder's persistent buffers and are overwritten by the
            next ``build`` call.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
            split_decodes_and_prefills(common_attn_metadata)

        page_size = self.kv_cache_spec.block_size
        max_q_len = common_attn_metadata.max_query_len
        # Python int (not a 0-d tensor), matching FlashInferMetadata's
        # `max_seq_len: int` and the int expected by use_trtllm_attention.
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor

        # Number of KV pages each request occupies (ceil division).
        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # Grab the blocks of the shared prefix from the first request.
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size

            # Create CPU versions directly for cascade (no GPU versions needed)
            shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device='cpu')
            shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
                                                     dtype=torch.int32,
                                                     device='cpu')
            # NOTE(review): despite the `_cpu` suffix this is a slice of the
            # block table tensor and lives on whatever device that is on.
            shared_kv_page_indices_cpu = block_table_tensor[
                0, :num_common_kv_blocks]
            shared_kv_last_page_len_cpu = torch.tensor([page_size],
                                                       dtype=torch.int32,
                                                       device='cpu')

            # Remove the blocks of the shared prefix from all requests.
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
            block_table_bounds_cpu -= num_common_kv_blocks
        else:
            shared_qo_indptr_cpu = None
            shared_kv_page_indptr_cpu = None
            shared_kv_page_indices_cpu = None
            shared_kv_last_page_len_cpu = None

        max_num_blocks = int(block_table_bounds_cpu.max())
        block_table_bounds = block_table_bounds_cpu.to(self.device,
                                                       non_blocking=True)
        # mask[i, j] is True for the first block_table_bounds[i] columns.
        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                < block_table_bounds.unsqueeze(1))

        # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
        torch.cumsum(block_table_bounds_cpu,
                     dim=0,
                     dtype=torch.int32,
                     out=self.paged_kv_indptr_cpu[1:1 + num_reqs])

        # The last indptr entry is the total page count, i.e. the number of
        # True entries in `mask`. Reading it from the host buffer avoids the
        # device->host sync that slicing with `torch.sum(mask)` would force.
        num_actual_pages = int(self.paged_kv_indptr_cpu[num_reqs])

        # write self.paged_kv_indices inplace
        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
        torch.masked_select(block_table_tensor[:, :max_num_blocks],
                            mask,
                            out=paged_kv_indices)

        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
        # write self.paged_kv_last_page_len_cpu inplace; a remainder of 0
        # means the last page is full.
        torch.where(paged_kv_last_page_len_cpu == 0,
                    torch.tensor(page_size),
                    paged_kv_last_page_len_cpu,
                    out=self.paged_kv_last_page_len_cpu[:num_reqs])

        cache_dtype = self.cache_config.cache_dtype
        if cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                cache_dtype)
        else:
            kv_cache_dtype = self.kv_cache_spec.dtype

        num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
            self.vllm_config.parallel_config)
        num_kv_heads = self.kv_cache_spec.num_kv_heads
        head_dim = self.kv_cache_spec.head_size

        # currently prefill trtllm attention does not support fp8 kv cache
        # trtllm may not support sliding window
        prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
                              and not cache_dtype.startswith("fp8")
                              and use_trtllm_attention(
                                  num_prefill_tokens, max_seq_len, cache_dtype,
                                  num_qo_heads, num_kv_heads, head_dim))
        decode_use_trtllm = (self.global_hyperparameters.window_left == -1
                             and use_trtllm_attention(
                                 num_decode_tokens, max_seq_len, cache_dtype,
                                 num_qo_heads, num_kv_heads, head_dim))

        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
            paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len_cpu=self.
            paged_kv_last_page_len_cpu[:num_reqs],
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            page_size=page_size,
            kv_data_type=kv_cache_dtype,
            q_data_type=self.vllm_config.model_config.dtype,
            slot_mapping=common_attn_metadata.slot_mapping,
            max_q_len=max_q_len,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table_tensor=block_table_tensor,
            prefill_use_trtllm=prefill_use_trtllm,
            decode_use_trtllm=decode_use_trtllm,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            use_cascade=use_cascade,
            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
        )

        self._plan(attn_metadata)

        return attn_metadata

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with FlashInfer.
        """
        m = common_attn_metadata

        assert m.num_reqs == m.num_actual_tokens, \
            "FlashInfer only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seq."

        # NOTE: mutates the caller's metadata object.
        m.max_query_len = 1  # decode-only

        return self.build(0, m)

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        """Return True only for decode-only batches (max query length 1)."""
        return common_attn_metadata.max_query_len == 1

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Decide whether to use cascade attention for this batch.

        Defers to the module-level ``use_cascade_attention`` heuristic, but
        always declines when the KV-cache dtype differs from the model dtype.
        """
        if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        return use_cascade_attention(*args, **kwargs)


class FlashInferImpl(AttentionImpl):
    """Attention implementation backed by FlashInfer kernels.

    Only decoder self-attention is supported. ``forward`` writes new keys and
    values into the paged KV cache (unless this layer reuses another layer's
    cache) and then runs the cascade, prefill and/or decode paths prepared in
    the ``FlashInferMetadata``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the per-layer attention hyperparameters.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale; stored as a Python float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes; converted to a float32 tensor.
            sliding_window: Optional window size. ``None`` is stored as
                ``(-1, -1)``, otherwise as ``(sliding_window - 1, 0)``. The
                first element is what ``forward`` uses as ``window_left``.
            kv_cache_dtype: KV cache dtype string; values starting with
                ``"fp8"`` make ``forward`` view the cache as an fp8 dtype.
            logits_soft_cap: Optional logits soft cap. ``None`` is compared as
                0.0 against the value the FlashInfer wrappers were planned with.
            attn_type: Only ``AttentionType.DECODER`` is supported.
            kv_sharing_target_layer_name: Name of an earlier layer whose KV
                cache this layer reuses. When set, this layer does not write
                to the KV cache.
            sinks: Optional float32 tensor with ``num_heads`` entries along
                dim 0.

        Raises:
            NotImplementedError: If ``attn_type`` is not ``DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (left, right) window; (-1, -1) means no sliding window.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

        # NOTE(review): self.sinks is validated and stored but is not passed
        # to any kernel in forward(); confirm that attention sinks are meant
        # to be ignored on these paths.
        self.sinks: Optional[torch.Tensor] = None
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads "
                "as the number of heads in the layer"
            )
            assert sinks.dtype == torch.float32, "Sinks must be of type float32"
            self.sinks = sinks

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Args:
            layer: The attention layer. Its ``_k_scale``/``_v_scale`` are used
                for the cache write, and ``_k_scale_float``/``_v_scale_float``
                for the attention kernels.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape -
                NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
                HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention; ``None`` in profiling runs.
            output: Pre-allocated output buffer (required). The first
                ``num_actual_tokens`` rows are written in place.
            output_scale: Fused output quantization is not supported; must be
                ``None``.
        Returns:
            The full (possibly padded) ``output`` buffer,
            shape = [num_tokens, num_heads * head_size]

        Raises:
            NotImplementedError: If ``output_scale`` is given.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashInferImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8. This is done
        # outside the write branch above: layers that share another layer's
        # KV cache skip the write but still read the cache below.
        if self.kv_cache_dtype.startswith("fp8"):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype)
            kv_cache = kv_cache.view(torch_dtype)

        # self.sliding_window is always a tuple (see __init__); its first
        # entry is the left window, -1 when no window is configured.
        window_left = self.sliding_window[0]

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
        output_padded = output
        output = output[:num_actual_tokens]

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            # Return the padded buffer, as the regular path does.
            return output_padded

        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        kv_cache_permute = kv_cache.permute(*stride_order)

        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
        if num_prefill_tokens > 0:
            prefill_wrapper = attn_metadata.prefill_wrapper
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
            assert prefill_wrapper is not None

            if not attn_metadata.prefill_use_trtllm:
                # The wrapper was planned in the metadata builder; make sure
                # it was planned with this layer's settings.
                assert prefill_wrapper._causal
                assert prefill_wrapper._window_left == window_left
                assert prefill_wrapper._logits_soft_cap == (
                    self.logits_soft_cap or 0.0)
                assert prefill_wrapper._sm_scale == self.scale
                prefill_wrapper.run(
                    prefill_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[num_decode_tokens:],
                )
            else:
                # prefill_query may be non-contiguous
                prefill_query = prefill_query.contiguous()
                workspace_buffer = prefill_wrapper._float_workspace_buffer
                # Block tables and sequence lengths hold one row per request
                # (they are paired with batch_size=num_prefills below), so
                # skip the decode requests rather than the decode tokens.
                num_decodes = attn_metadata.num_decodes
                block_tables_prefill = attn_metadata.block_table_tensor[
                    num_decodes:]
                seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert prefill_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_prefill.is_contiguous()
                assert seq_lens_prefill.is_contiguous()

                trtllm_batch_context_with_kv_cache(
                    query=prefill_query,
                    kv_cache=kv_cache_permute,
                    workspace_buffer=workspace_buffer,
                    block_tables=block_tables_prefill,
                    seq_lens=seq_lens_prefill,
                    max_q_len=attn_metadata.max_q_len,
                    max_kv_len=attn_metadata.max_seq_len,
                    bmm1_scale=layer._k_scale_float * self.scale,
                    bmm2_scale=layer._v_scale_float,
                    batch_size=attn_metadata.num_prefills,
                    cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                    cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
                    out=output[num_decode_tokens:],
                )

        if num_decode_tokens > 0:
            decode_wrapper = attn_metadata.decode_wrapper
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            assert decode_wrapper is not None

            if not attn_metadata.decode_use_trtllm:
                # The wrapper was planned in the metadata builder; make sure
                # it was planned with this layer's settings.
                assert decode_wrapper._window_left == window_left
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                           or 0.0)
                assert decode_wrapper._sm_scale == self.scale
                decode_wrapper.run(
                    decode_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[:num_decode_tokens],
                )
            else:
                # decode_query may be non-contiguous
                decode_query = decode_query.contiguous()
                workspace_buffer = decode_wrapper._float_workspace_buffer
                # Per-request tensors: take the decode requests at the front.
                num_decodes = attn_metadata.num_decodes
                block_tables_decode = attn_metadata.block_table_tensor[
                    :num_decodes]
                seq_lens_decode = attn_metadata.seq_lens[:num_decodes]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert decode_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_decode.is_contiguous()
                assert seq_lens_decode.is_contiguous()

                trtllm_batch_decode_with_kv_cache(
                    query=decode_query,
                    kv_cache=kv_cache_permute,
                    workspace_buffer=workspace_buffer,
                    block_tables=block_tables_decode,
                    seq_lens=seq_lens_decode,
                    max_seq_len=attn_metadata.max_seq_len,
                    bmm1_scale=layer._k_scale_float * self.scale,
                    bmm2_scale=layer._v_scale_float,
                    out=output[:num_decode_tokens],
                )
        return output_padded


def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    q_data_type: Optional[Union[str, torch.dtype]] = "float16",
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
    data_type: Optional[Union[str, torch.dtype]] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    non_blocking: bool = True,
) -> None:
    """
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan for
    cudagraph capture/replay.

    It is meant to be called with ``self`` bound to a decode wrapper.
    - On the first call, or whenever the wrapper is not in cudagraph mode,
      it calls the wrapper's original ``plan`` with all arguments unchanged
      and marks the wrapper with ``vllm_first_call = False``.
    - Otherwise it re-plans directly through ``self._cached_module.plan``
      with these modifications for cudagraph:
      - only host-to-device copies of the indptr and last_page_len buffers;
      - no device-to-device copy of the indices buffer (``indices`` is only
        length-checked against the wrapper's buffer).

    Parts of the code are inspired by the original plan in the FlashInfer
    repo and by the fast_decode_plan implementation for FlashInfer in the
    SGLang repo.

    Args:
        indptr_cpu: Host-side page indptr, copied into
            ``self._paged_kv_indptr_buf``.
        indices: Page indices. Forwarded to the original plan; in the
            cudagraph path only its length is checked.
        last_page_len_cpu: Host-side last-page lengths. Its length is the
            batch size; copied into ``self._paged_kv_last_page_len_buf``.
        num_qo_heads, num_kv_heads, head_dim, page_size: Forwarded to the
            plan call.
        pos_encoding_mode, rope_scale, rope_theta: In the cudagraph path
            these are only recorded on the wrapper.
        window_left, logits_soft_cap: Forwarded to the plan and recorded on
            the wrapper. In the cudagraph path ``None`` soft cap becomes 0.0.
        q_data_type, kv_data_type, data_type: ``data_type`` fills in a
            missing q/kv dtype, and the kv dtype defaults to the q dtype.
            Strings are resolved as ``torch`` attributes.
        sm_scale: Recorded on the wrapper as ``_sm_scale``.
        non_blocking: Only forwarded to the original ``plan``. The cudagraph
            path always copies with ``non_blocking=True``.

    Raises:
        ValueError: If the batch size differs from the wrapper's fixed batch
            size, or ``indices`` is longer than the allocated indices buffer.
        RuntimeError: Wrapping any exception raised by the cached plan.
    """
    # Warm up with the original plan if it is first call, and always run the
    # original plan if we run for dynamic shape. For fixed shape (cudagraph),
    # this warm up is to generate the _cached_module for the decode wrapper.
    if not self.is_cuda_graph_enabled or \
        getattr(self, "vllm_first_call", True):
        self.plan(
            indptr_cpu,
            indices,
            last_page_len_cpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            pos_encoding_mode,
            window_left,
            logits_soft_cap,
            q_data_type,
            kv_data_type,
            data_type,
            sm_scale,
            rope_scale,
            rope_theta,
            non_blocking,
        )
        self.vllm_first_call = False
        return

    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

    # The batch size is taken from the number of last-page lengths.
    batch_size = len(last_page_len_cpu)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type
    q_data_type = getattr(torch, q_data_type) if isinstance(
        q_data_type, str) else q_data_type
    kv_data_type = getattr(torch, kv_data_type) if isinstance(
        kv_data_type, str) else kv_data_type

    # The tensor-core plan also takes a host-side query indptr. Presumably
    # _get_range_buf yields [0, 1, ..., batch_size], i.e. one query token per
    # request. TODO confirm against FlashInfer.
    if self.use_tensor_cores:
        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

    # The batch size is fixed in cudagraph mode, and indices must fit in the
    # buffer the wrapper already holds.
    if batch_size != self._fixed_batch_size:
        raise ValueError(
            "The batch size should be fixed in cudagraph mode, the runtime "
            "batch size {} mismatches the batch size set during "
            "initialization {}".format(batch_size, self._fixed_batch_size))
    if len(indices) > len(self._paged_kv_indices_buf):
        raise ValueError(
            "The size of indices should be less than or equal to the "
            "allocated buffer")

    # host-to-device copy for the indptr buffer
    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
    # host-to-device copy for the last_page_len buffer
    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                           non_blocking=True)

    # The plan itself runs on the host-side copies.
    indptr_host = indptr_cpu
    last_page_len_host = last_page_len_cpu

    if self.use_tensor_cores:
        kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host,
                                        page_size)

        try:
            # Make sure we pass exactly 15 arguments for tensor core version
            self._plan_info = self._cached_module.plan(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
                self._pin_memory_int_workspace_buffer,
                qo_indptr_host,
                indptr_host,
                kv_lens_arr_host,
                batch_size,  # total_num_rows
                batch_size,
                num_qo_heads,
                num_kv_heads,
                page_size,
                self.is_cuda_graph_enabled,
                head_dim,
                head_dim,
                False,  # causal
            )
        except Exception as e:
            raise RuntimeError(f"Error in tensor core plan: {e}") from e
    else:
        try:
            # Make sure we pass exactly 15 arguments for standard version
            self._plan_info = self._cached_module.plan(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
                self._pin_memory_int_workspace_buffer,
                indptr_host,
                batch_size,
                num_qo_heads,
                num_kv_heads,
                page_size,
                self.is_cuda_graph_enabled,
                window_left,
                logits_soft_cap,
                head_dim,
                head_dim,
                # Empty tensors presumably only convey the dtypes.
                torch.empty(0, dtype=q_data_type),
                torch.empty(0, dtype=kv_data_type),
            )
        except Exception as e:
            raise RuntimeError(f"Error in standard plan: {e}") from e

    # Record the planned parameters on the wrapper. The decode path in
    # FlashInferImpl.forward asserts _window_left, _logits_soft_cap and
    # _sm_scale against the layer's own settings.
    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta