flashinfer.py 51.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
9

10
11
from vllm.multimodal import MultiModalPlaceholderMap

12
13
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
14
15
    from flashinfer.decode import (CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
                                   trtllm_batch_decode_with_kv_cache)
16
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
17

18
    from vllm.vllm_flash_attn import flash_attn_varlen_func
19
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
20
except ImportError:
21
22
23
24
25
    # Avoid turning these types into variables during type checking
    if not TYPE_CHECKING:
        BatchDecodeWithPagedKVCacheWrapper = None
        CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
        BatchPrefillWithPagedKVCacheWrapper = None
26
        trtllm_batch_decode_with_kv_cache = None
27
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
28
29
    raise ImportError("FlashInfer is not installed. Please install it from "
                      "https://github.com/flashinfer-ai/flashinfer") from None
30

31
32
import torch

33
import vllm.envs as envs
34
35
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
36
                                              AttentionLayer,
37
38
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
39
                                              AttentionState, AttentionType)
40
41
42
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
43
from vllm.attention.layer import Attention
44
from vllm.attention.ops.paged_attn import PagedAttention
45
from vllm.config import VllmConfig, get_layers_from_vllm_config
46
from vllm.logger import init_logger
47
from vllm.platforms import current_platform
48
49
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                        make_tensor_with_pad)
50

51
52
logger = init_logger(__name__)

53
if TYPE_CHECKING:
54
55
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
56
57
58


class FlashInferBackend(AttentionBackend):
    """Attention backend descriptor for FlashInfer.

    Exposes the implementation, metadata, builder and state classes used by
    the model runner, together with KV-cache shape/layout helpers and the
    heuristic that decides whether TRTLLM decode attention is used.
    """

    # Lazily-populated result of `current_platform.has_device_capability(100)`;
    # `None` means the device has not been probed yet.
    cached_sm100a_supported: Optional[bool] = None

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashInferState"]:
        return FlashInferState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the logical KV-cache shape.

        The size-2 axis at dimension 1 separates the two cache planes
        (presumably key and value).
        """
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        """Return the physical dimension order of the KV cache.

        "NHD" keeps the logical order of `get_kv_cache_shape`; "HND" swaps
        the block-size axis and the kv-head axis.
        """
        cache_layout = FlashInferState.get_kv_cache_layout()
        assert (cache_layout in ("NHD", "HND"))
        stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3,
                                                                      2, 4)
        return stride_order

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap KV-cache blocks; delegates to `PagedAttention.swap_blocks`."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy KV-cache blocks; delegates to `PagedAttention.copy_blocks`."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 128, 256]

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Map a vLLM FP8 kv-cache dtype string to a torch FP8 dtype.

        Raises:
            ValueError: If `kv_cache_dtype` is not a recognized FP8 name.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    @staticmethod
    def use_trtllm_decode_attention(
        batch_size: int,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_qo_heads: Optional[int],
        num_kv_heads: Optional[int],
        attn_head_size: Optional[int],
    ) -> bool:
        """Decide whether TRTLLM decode attention should be used.

        Requires a device with compute capability 10.0 (probed once and
        cached), head size 128, and a query-to-KV head ratio that is an
        integer no greater than 8. If `VLLM_USE_TRTLLM_DECODE_ATTENTION` is
        set, it decides ("0" disables; anything else enables). Otherwise the
        path is auto-enabled for batch_size <= 256, max_seq_len < 131072 and
        an "auto" kv-cache dtype.
        """
        if FlashInferBackend.cached_sm100a_supported is None:
            FlashInferBackend.cached_sm100a_supported = (
                current_platform.has_device_capability(100))
        if not FlashInferBackend.cached_sm100a_supported:
            return False
        # Check if the dimensions are supported by TRTLLM decode attention.
        # `num_kv_heads` is checked for positivity before it is used as a
        # divisor so a zero value returns False instead of raising.
        if (attn_head_size is None or num_qo_heads is None
                or num_kv_heads is None or num_kv_heads <= 0
                or num_qo_heads % num_kv_heads != 0
                or num_qo_heads // num_kv_heads > 8
                or attn_head_size != 128):
            return False

        env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
        if env_value is not None:
            logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                             env_value)
            # The environment variable takes precedence. Only "0" disables
            # the path, because without it the path is enabled automatically
            # whenever the batch-size condition below is satisfied.
            use_trtllm = env_value != "0"
            if use_trtllm:
                logger.info_once("Using TRTLLM decode attention.")
            return use_trtllm

        # Environment variable not set - use auto-detection. Device support
        # has already been confirmed above.
        use_trtllm = (batch_size <= 256 and max_seq_len < 131072
                      and kv_cache_dtype == "auto")
        if use_trtllm:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

167

168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
@dataclasses.dataclass
class PerLayerParameters:
    """Attention hyperparameters read from a single attention layer.

    The FlashInfer backend currently requires every layer of a model to
    agree on all of these values.
    """

    # Left (inclusive) attention-window size; -1 means no window limit.
    window_left: int
    # Soft-capping value applied to attention logits (may be None).
    logits_soft_cap: Optional[float]
    # Scale applied inside the attention softmax.
    sm_scale: float


def get_per_layer_parameters(
        vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
    """Gather the `plan`-time hyperparameters of every attention layer.

    Args:
        vllm_config: Configuration whose `Attention` layers are scanned via
            `get_layers_from_vllm_config`.

    Returns:
        A dict with the same keys as `get_layers_from_vllm_config` returns
        (layer names, presumably), mapping each to its
        `PerLayerParameters`.
    """

    def _params_of(layer) -> PerLayerParameters:
        # Extract the hyperparameters from one layer's FlashInfer impl.
        impl = layer.impl
        assert isinstance(impl, FlashInferImpl)
        sliding_window = impl.sliding_window
        # Only the left edge of the window matters; -1 means "full length".
        left = -1 if sliding_window is None else sliding_window[0]
        return PerLayerParameters(left, impl.logits_soft_cap, impl.scale)

    attention_layers = get_layers_from_vllm_config(vllm_config, Attention)
    return {
        name: _params_of(layer)
        for name, layer in attention_layers.items()
    }


def infer_global_hyperparameters(
        per_layer_params: Dict[str, "PerLayerParameters"]
) -> "PerLayerParameters":
    """Return the hyperparameters shared by every attention layer.

    Currently, the FlashInfer backend only supports models in which all
    layers share the same values for these hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    This function verifies that invariant and returns the common values.
    The checks use explicit raises rather than `assert`, so they still run
    under `python -O`. Otherwise a heterogeneous model would silently get
    layer 0's values everywhere.

    Args:
        per_layer_params: Per-layer hyperparameters keyed by layer, as
            produced by `get_per_layer_parameters`.

    Returns:
        The hyperparameters common to all layers.

    Raises:
        ValueError: If `per_layer_params` is empty, or if any layer's
            hyperparameters differ from those of the first layer.
    """
    if not per_layer_params:
        raise ValueError("No attention layers found in the model.")

    layer_items = iter(per_layer_params.items())
    first_name, global_params = next(layer_items)
    # Compare every remaining layer against the first one.
    for name, params in layer_items:
        if params != global_params:
            raise ValueError(
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values for the following "
                "hyperparameters: `window_left`, `logits_soft_cap`, "
                f"`sm_scale`. Layer {name!r} has {params}, but layer "
                f"{first_name!r} has {global_params}.")

    return global_params


232
233
234
235
236
237
238
239
240
class FlashInferState(AttentionState):
    """Runner-level FlashInfer state.

    Owns the FlashInfer workspace buffer and the lazily-created prefill and
    decode wrappers. Also manages the persistent buffers and decode wrapper
    used while CUDA graphs are captured.
    """

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

        # Global hyperparameters shared by all attention layers. Filled in on
        # first use by `_get_global_hyperparameters` and reused afterwards.
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = self.runner.vllm_config
        self._kv_cache_layout = None

    def _get_workspace_buffer(self):
        """Return the FlashInfer workspace buffer, allocating it on first use.

        The buffer is a uint8 tensor of `FLASHINFER_WORKSPACE_BUFFER_SIZE`
        bytes on the runner's device. It is shared by every wrapper this
        state creates.
        """
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    @staticmethod
    def get_kv_cache_layout():
        """Return the KV cache layout string.

        Precedence: the v1 `_KV_CACHE_LAYOUT_OVERRIDE`, then the
        `VLLM_KV_CACHE_LAYOUT` environment setting, then "NHD".
        """
        from vllm.v1.attention.backends.utils import _KV_CACHE_LAYOUT_OVERRIDE
        if _KV_CACHE_LAYOUT_OVERRIDE is not None:
            logger.info_once("Using KV cache layout %s",
                             _KV_CACHE_LAYOUT_OVERRIDE)
            return _KV_CACHE_LAYOUT_OVERRIDE
        cache_layout = envs.VLLM_KV_CACHE_LAYOUT
        if cache_layout is None:
            logger.info_once("Using default KV cache layout NHD")
            return "NHD"
        logger.info_once("Using KV cache layout %s", cache_layout)
        return cache_layout

    def _get_head_config(self) -> Tuple[int, int, bool]:
        """Return `(num_qo_heads, num_kv_heads, use_tensor_cores)`.

        Tensor cores are used when forced via
        `VLLM_FLASHINFER_FORCE_TENSOR_CORES` or when there are more than 4
        query heads per KV head.
        """
        parallel_config = self.runner.parallel_config
        num_qo_heads = self.runner.model_config.get_num_attention_heads(
            parallel_config)
        num_kv_heads = self.runner.model_config.get_num_kv_heads(
            parallel_config)
        use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
            num_qo_heads // num_kv_heads > 4)
        return num_qo_heads, num_kv_heads, use_tensor_cores

    def _get_global_hyperparameters(self) -> PerLayerParameters:
        """Return the model-wide attention hyperparameters.

        They are computed on first call and cached, because scanning every
        layer does not need to be repeated for each captured batch size.
        """
        if self.global_hyperparameters is None:
            self.global_hyperparameters = infer_global_hyperparameters(
                get_per_layer_parameters(self.vllm_config))
        return self.global_hyperparameters

    def _get_prefill_wrapper(self):
        """Return the prefill wrapper, creating it on first use."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), self.get_kv_cache_layout())
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        """Return the (non-graph) decode wrapper, creating it on first use."""
        if self._decode_wrapper is None:
            _, _, use_tensor_cores = self._get_head_config()
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                self.get_kv_cache_layout(),
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Allocate the persistent buffers used during CUDA graph capture.

        All buffers are sized for `max_batch_size` and placed on the runner's
        device. They are deleted when the context exits. Cleanup runs in a
        `finally` block, so an exception raised during capture does not
        leave the state flagged as capturing or keep the buffers alive.
        """
        self._is_graph_capturing = True
        try:
            self._graph_decode_wrapper = None
            self._graph_slot_mapping = torch.full((max_batch_size, ),
                                                  PAD_SLOT_ID,
                                                  dtype=torch.long,
                                                  device=self.runner.device)
            self._graph_seq_lens = torch.ones(max_batch_size,
                                              dtype=torch.int32,
                                              device=self.runner.device)
            self._graph_block_tables = torch.from_numpy(
                self.runner.graph_block_tables).to(device=self.runner.device)
            self._graph_decode_workspace_buffer = self._get_workspace_buffer()
            self._graph_indices_buffer = torch.empty(
                max_batch_size * self.runner.cache_config.num_gpu_blocks,
                dtype=torch.int32,
                device=self.runner.device)
            self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                    dtype=torch.int32,
                                                    device=self.runner.device)
            self._graph_last_page_len_buffer = torch.empty(
                max_batch_size, dtype=torch.int32, device=self.runner.device)
            yield
        finally:
            self._is_graph_capturing = False
            for attr in ("_graph_slot_mapping", "_graph_seq_lens",
                         "_graph_block_tables",
                         "_graph_decode_workspace_buffer",
                         "_graph_indices_buffer", "_graph_indptr_buffer",
                         "_graph_last_page_len_buffer",
                         "_graph_decode_wrapper"):
                # An attribute may be missing if allocation failed part-way.
                if hasattr(self, attr):
                    delattr(self, attr)

    def graph_clone(self, batch_size: int):
        """Create a new state that reuses the capture-time workspace buffer
        and the most recently built graph decode wrapper."""
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only metadata for capturing a graph of `batch_size`.

        Also rebuilds `self._graph_decode_wrapper` on views of the persistent
        capture buffers and calls `begin_forward` on the returned metadata.
        """
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads, num_kv_heads, use_tensor_cores = self._get_head_config()
        self._graph_decode_wrapper = (
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                self._graph_decode_workspace_buffer, _indptr_buffer,
                self._graph_indices_buffer, _last_page_len_buffer,
                self.get_kv_cache_layout(), use_tensor_cores))

        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        # Dummy host-side page tables for capture: one full page per
        # sequence.
        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
                                                        self.runner.block_size,
                                                        dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        global_params = self._get_global_hyperparameters()

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            max_decode_seq_len=0,
            seq_lens_tensor=self._graph_seq_lens,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None,
            **dataclasses.asdict(global_params),
        )
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        """Return the metadata tensors that act as graph input buffers."""
        return {
            "block_tables": attn_metadata.block_tables,
            "seq_lens_tensor": attn_metadata.seq_lens_tensor,
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        """Copy this step's seq lens and block tables into the graph buffers.

        Only the leading rows, one per sequence in the batch, are
        overwritten.
        """
        # FlashInfer-specific logic: copy additional tensors
        num_seqs = attn_metadata.decode_metadata.seq_lens_tensor.shape[0]
        input_buffers["seq_lens_tensor"][:num_seqs].copy_(
            attn_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"][:num_seqs].copy_(
            attn_metadata.block_tables, non_blocking=True)

    def begin_forward(self, model_input):
        """Attach wrappers to `model_input.attn_metadata` and plan them.

        For pure-decode batches under CUDA graphs, the wrappers come from the
        matching graph runner's state instead of this one.
        """
        assert not self._is_graph_capturing
        state = self

        use_cuda_graph = model_input.attn_metadata.use_cuda_graph
        is_decode = model_input.attn_metadata.num_prefills == 0
        # In case of multistep chunked-prefill, there might be prefill requests
        # scheduled while CUDA graph mode is enabled. We don't run graph in that
        # case.
        if use_cuda_graph and is_decode:
            if model_input.inputs_embeds is None:
                batch_size = model_input.input_tokens.shape[0]
                state = (
                    self.runner.graph_runners[model_input.virtual_engine][(
                        batch_size, False)].attn_state)
            else:
                batch_size = model_input.inputs_embeds.shape[0]
                state = (
                    self.runner.graph_runners[model_input.virtual_engine][(
                        batch_size, True)].attn_state)

        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
        )
        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
        model_input.attn_metadata.begin_forward()


450
@dataclass
class FlashInferMetadata(AttentionMetadata):
    """Attention metadata for the FlashInfer backend.

    Besides the batch bookkeeping inherited from `AttentionMetadata`, it
    carries the paged-KV page tables, head/page geometry, dtypes, the shared
    layer hyperparameters, and the FlashInfer wrappers that `begin_forward`
    plans.
    """

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Decode-side counterpart of `max_prefill_seq_len` (set to 0 by the
    # CUDA-graph capture path).
    max_decode_seq_len: int

    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int] = 1

    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # used for GPU in-place advance_step
    seq_lens_tensor: Optional[torch.Tensor] = None
    block_table_bound: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None
    # The data type of the query
    q_data_type: Optional[torch.dtype] = None
    # FlashInfer 0.2 encourages passing host tensors
    device: torch.device = torch.device("cpu")
    is_profile_run: bool = False

    # The FlashInfer backend currently supports only models in which all layers
    # share the same following hyperparameters:

    # The left (inclusive) window size for the attention window, when
    # set to `-1`, the window size will be set to the full length of
    # the sequence. Defaults to `-1`.
    window_left: int = -1
    # The attention logits soft capping value (used in Gemini, Grok and
    # Gemma-2, etc.), if not provided, will be set to `0`. If greater
    # than 0, the logits will be capped according to formula:
    # $$\texttt{logits\_soft\_cap} \times
    # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
    # where $x$ is the input logits.
    logits_soft_cap: Optional[float] = None
    # The scale used in softmax, if not provided, will be set to
    # `1.0 / sqrt(head_dim)`.
    sm_scale: Optional[float] = None

    def __post_init__(self):
        """Validate `head_dim` against FlashInfer's supported head sizes.

        Raises:
            ValueError: If `head_dim` is set and is not a supported size.
        """
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if (self.head_dim is not None
                and self.head_dim not in supported_head_sizes):
            # Adjacent f-strings (no comma) form a single message string.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim,"
                f" received {self.head_dim}.")

    def begin_forward(self):
        """Move page-table tensors to `self.device` and plan the wrappers.

        `prefill_wrapper` is planned when there are prefill tokens, except on
        profile runs. `decode_wrapper` is planned when there are decode
        tokens. Prefill requests occupy the leading `num_prefills` entries of
        the per-request tables; decode requests follow them.
        """
        if self.num_prefill_tokens > 0:
            if self.paged_kv_indices is None:
                # NOTE: this returns from the whole method, so decode
                # planning is skipped as well.
                return

            assert self.prefill_wrapper is not None
            assert self.query_start_loc is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            assert self.block_table_bound is not None
            assert self.seq_lens_tensor is not None
            self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
            batch_size = self.query_start_loc.shape[0] - 1
            assert batch_size >= 0
            # We will use flash attention for profiling to
            # determine the number of blocks. Therefore,
            # we don't need to prepare the input for flashinfer for profile run.
            if not self.is_profile_run:
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)
                self.block_table_bound = self.block_table_bound.to(self.device)
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.prefill_wrapper.plan(
                    self.query_start_loc,
                    self.paged_kv_indptr[:self.num_prefills + 1],
                    self.paged_kv_indices,
                    self.paged_kv_last_page_len[:self.num_prefills],
                    self.num_qo_heads,
                    self.num_kv_heads,
                    self.head_dim,
                    self.page_size,
                    causal=True,
                    sm_scale=self.sm_scale,
                    window_left=self.window_left,
                    logits_soft_cap=self.logits_soft_cap,
                    q_data_type=self.q_data_type,
                    kv_data_type=self.data_type)
        if self.num_decode_tokens > 0:
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            # handle model warmup path
            if self.block_table_bound is not None:
                self.block_table_bound = self.block_table_bound.to(self.device)
            if self.seq_lens_tensor is not None:
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

            assert self.decode_wrapper is not None
            self.decode_wrapper.plan(
                self.paged_kv_indptr[self.num_prefills:],
                self.paged_kv_indices,
                self.paged_kv_last_page_len[self.num_prefills:],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                window_left=self.window_left,
                logits_soft_cap=self.logits_soft_cap,
                sm_scale=self.sm_scale,
                # kv-cache data type.
                kv_data_type=self.data_type,
                # query data type.
                q_data_type=self.q_data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Return the fields as a dict, always excluding the wrappers.

        The prefill/decode wrappers are skipped because they cannot be
        broadcast with NCCL when TP is enabled. The caller's `skip_fields` is
        copied rather than modified in place.
        """
        skip_fields = set() if skip_fields is None else set(skip_fields)
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        """`self` if the batch contains prefills, else None."""
        if self.num_prefills == 0:
            return None
        return self

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        """`self` if the batch contains decode tokens, else None."""
        if self.num_decode_tokens == 0:
            return None
        return self

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Writes `sampled_token_ids` into the first `num_queries` slots of
        `model_input.input_tokens`, then calls `ops.advance_step_flashinfer`
        to update positions, seq lens, slot mapping and page tables on the
        device. With `turn_prefills_into_decodes`, all prefills are first
        reclassified as decodes.
        """
        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            # Flashinfer doesn't support speculative decoding + chunked-prefill
            # + multi-step scheduling yet.
            assert self.decode_query_len == 1
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens_tensor is not None

        assert num_seqs > 0
        assert num_queries > 0
        assert model_input.attn_metadata is not None
        assert sampled_token_ids is not None

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()

        # Update GPU tensors
        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=model_input.input_tokens,
            sampled_token_ids=model_input.input_tokens,
            input_positions=model_input.input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            block_table_bound=self.block_table_bound)

class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Accumulate per-sequence-group inputs and build ``FlashInferMetadata``.

    ``prepare`` resets the per-batch host-side lists; ``build`` feeds every
    entry of ``input_builder.inter_data_list`` through ``_add_seq_group`` and
    then converts the accumulated lists into tensors on ``runner.device``.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

        # Global hyperparameters shared by all attention layers. Inferred
        # lazily on the first call to ``prepare``.
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = self.runner.vllm_config

    def prepare(self):
        """Reset all per-batch state.

        On the first call this also infers the hyperparameters that must be
        identical across all layers (``window_left``, ``logits_soft_cap``,
        ``sm_scale``) and caches them on the builder.
        """
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        self.paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        self.paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        self.paged_kv_last_page_len: List[int] = []
        # Sum of len(block_table) over all added sequences; ``build`` pads
        # paged_kv_indices with zeros up to this length.
        self.total_blocks = 0
        self.is_profile_run: bool = False

        if self.global_hyperparameters is None:
            # Infer global hyperparameters, since currently we only support
            # models in which all layers share the same values for the
            # following hyperparameters:
            # - `window_left`
            # - `logits_soft_cap`
            # - `sm_scale`
            inferred_params = infer_global_hyperparameters(
                get_per_layer_parameters(self.vllm_config))
            self.global_hyperparameters = inferred_params
            self.window_left = inferred_params.window_left
            self.logits_soft_cap = inferred_params.logits_soft_cap
            self.sm_scale = inferred_params.sm_scale

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        4. paged-KV indices / indptr / last-page length (not during the
           profiling run).

        Args:
            inter_data: Per-sequence-group intermediate data from the input
                builder.
            chunked_prefill_enabled: When True, prompt sequences also take
                their block table from ``inter_data.block_tables`` (decode
                sequences always do).
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            is_profile_run = is_block_tables_empty(block_tables)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
            # and paged_kv_last_page_len for profile run because we will
            # create dummy inputs. Note this returns from the whole method,
            # so later sequences of this group are not visited.
            if is_profile_run:
                self.is_profile_run = is_profile_run
                return

            # The paged-KV tensors use the full (untrimmed) block table.
            block_table = block_tables[seq_id]
            self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
        """Append one sequence's entries to the paged-KV index lists.

        Args:
            block_table: Block (page) ids allocated to the sequence.
            seq_len: Number of tokens of the sequence.
        """
        self.total_blocks += len(block_table)
        # Get the number of valid blocks based on sequence length, i.e.
        # ceil(seq_len / block_size).
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        block_table_bound = (seq_len + self.block_size - 1) // self.block_size
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)

        # A completely filled last page has length block_size, not 0.
        last_page_len = seq_len % self.block_size or self.block_size
        self.paged_kv_last_page_len.append(last_page_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            A ``FlashInferMetadata`` describing the batch.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        # Prefill sequences come first, so everything after them is decode.
        decode_query_len = max(query_lens[self.num_prefills:], default=1)

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padded row. (The former
            # ``[] * cuda_graph_pad_size`` evaluated to ``[]`` and padded
            # nothing.) Empty rows are skipped by the copy loop below, so the
            # resulting tensors are unchanged.
            self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size - self.num_prefill_tokens

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            max_blocks = input_block_tables.shape[1]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    num_blocks = len(block_table)
                    if num_blocks <= max_blocks:
                        input_block_tables[i, :num_blocks] = block_table
                    else:
                        # It may be possible to have more blocks allocated due
                        # to lookahead slots of multi-step, however, they are
                        # not used anyway, so can be safely ignored.
                        input_block_tables[
                            i, :max_blocks] = block_table[:max_blocks]

            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)

            # Padded requests own zero pages: repeat the final indptr value.
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )

        assert device is not None
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }
        # Exclusive prefix sums: element 0 stays 0.
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        if len(self.paged_kv_indptr) > 0:
            # extend to the maximum number of blocks as returned by the
            # scheduler
            self.paged_kv_indices.extend(
                [0] * (self.total_blocks - len(self.paged_kv_indices)))
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
            block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
                                                   1,
                                                   device="cpu",
                                                   dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None
            block_table_bound_tensor = None

        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        return FlashInferMetadata(
            decode_query_len=decode_query_len,
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            block_table_bound=block_table_bound_tensor,
            seq_lens_tensor=seq_lens_tensor,
            num_qo_heads=self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config),
            num_kv_heads=self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config),
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=use_captured_graph,
            is_profile_run=self.is_profile_run,
            window_left=self.window_left,
            logits_soft_cap=self.logits_soft_cap,
            sm_scale=self.sm_scale,
        )


class FlashInferImpl(AttentionImpl):
    """V0 FlashInfer attention implementation (decoder self-attention only).

    Prefill tokens use the batch prefill wrapper. When the KV cache is empty
    (the profiling run), they use ``flash_attn_varlen_func`` instead.
    Decode tokens use the batch decode wrapper or, when
    ``FlashInferBackend.use_trtllm_decode_attention`` selects it,
    ``trtllm_batch_decode_with_kv_cache``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Store the per-layer attention configuration.

        Raises:
            NotImplementedError: If KV sharing is requested, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0 "
                                      "FLASHINFER backend.")
        if use_irope:
            logger.warning_once(
                "Using irope in FlashInfer is not supported yet, it will fall"
                " back to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (left, right) window passed to flash_attn_varlen_func; the left
        # element is also compared with the wrappers' _window_left.
        # (-1, -1) means no sliding window.
        self.sliding_window = ((sliding_window - 1,
                                0) if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention for a batch of prefill and/or decode tokens.

        Args:
            layer: Attention layer. It supplies ``_k_scale``/``_v_scale``,
                used when writing the cache, and ``_k_scale_float``/
                ``_v_scale_float``, passed to the attention kernels.
            query: 2-D tensor of shape (num_tokens, hidden_size), viewed as
                (-1, num_heads, head_size).
            key: Viewed as (-1, num_kv_heads, head_size).
            value: Viewed as (-1, num_kv_heads, head_size).
            kv_cache: Paged KV cache. Axis 1 selects K (0) or V (1). It is
                empty (``numel() == 0``) during the profiling run.
            attn_metadata: Batch metadata. Prefill tokens come first,
                followed by decode tokens.
            output: Currently ignored (see TODO below).
            output_scale: Must be None.

        Returns:
            Tensor of shape (num_tokens, hidden_size).

        Raises:
            NotImplementedError: If ``output_scale`` is not None.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashInferImpl")

        # TODO: directly write to output tensor
        num_heads: int = self.num_heads
        head_size: int = self.head_size
        num_kv_heads: int = self.num_kv_heads
        kv_cache_dtype: str = self.kv_cache_dtype
        softmax_scale: float = self.scale
        window_size = self.sliding_window
        alibi_slopes = self.alibi_slopes
        logits_soft_cap = self.logits_soft_cap

        num_tokens, hidden_size = query.shape
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)

        if kv_cache.numel() > 0:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )
            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if kv_cache_dtype.startswith("fp8"):
                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    kv_cache_dtype)
                kv_cache = kv_cache.view(torch_dtype)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
                    f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
                    f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
        query = query.contiguous(
        )  # Flashinfer requires query to be contiguous
        # Query for decode. KV is not needed because it is already cached.
        # QKV for prefill.
        decode_query = query[num_prefill_tokens:]
        query = query[:num_prefill_tokens]

        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        window_left = window_size[0] if window_size is not None else -1

        prefill_output: Optional[torch.Tensor] = None
        decode_output: Optional[torch.Tensor] = None
        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        if prefill_meta := attn_metadata.prefill_metadata:
            # We will use flash attention for prefill
            # when kv_cache is not provided.
            # This happens when vllm runs the profiling to
            # determine the number of blocks.
            if kv_cache.numel() == 0:
                prefill_output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=softmax_scale,
                    causal=True,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                )
            else:
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None
                # The wrapper was planned with the global hyperparameters;
                # check this layer agrees with them.
                assert prefill_meta.prefill_wrapper._causal
                assert prefill_meta.prefill_wrapper._window_left == window_left
                assert prefill_meta.prefill_wrapper._logits_soft_cap == (
                    logits_soft_cap or 0.0)
                assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale

                prefill_output = prefill_meta.prefill_wrapper.run(
                    query,
                    kv_cache.permute(*stride_order),
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                )
        if decode_meta := attn_metadata.decode_metadata:
            assert decode_meta is not None
            assert decode_meta.decode_wrapper is not None

            assert decode_meta.decode_wrapper._window_left == window_left
            assert decode_meta.decode_wrapper._logits_soft_cap == (
                logits_soft_cap or 0.0)
            assert decode_meta.decode_wrapper._sm_scale == softmax_scale
            # TODO: @pavanimajety Remove this once the switch happens
            # inside flashinfer.
            if not FlashInferBackend.use_trtllm_decode_attention(
                    num_decode_tokens, attn_metadata.max_decode_seq_len,
                    kv_cache_dtype, attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
                decode_output = decode_meta.decode_wrapper.run(
                    decode_query,
                    kv_cache.permute(*stride_order),
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                )
            else:
                # The TRT-LLM kernel reuses the decode wrapper's integer
                # workspace buffer.
                workspace_buffer = (
                    decode_meta.decode_wrapper._int_workspace_buffer)
                assert FlashInferState.get_kv_cache_layout() == "HND"
                decode_output = trtllm_batch_decode_with_kv_cache(
                    query=decode_query,
                    kv_cache=kv_cache.permute(*stride_order),
                    workspace_buffer=workspace_buffer,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    scale=softmax_scale,
                    block_tables=attn_metadata.block_tables,
                    seq_lens=decode_meta.seq_lens_tensor,
                    block_size=attn_metadata.page_size,
                    max_seq_len=attn_metadata.max_decode_seq_len,
                    kv_cache_dtype=kv_cache_dtype,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float)

        if prefill_output is None and decode_output is not None:
            # Decode only batch.
            output, num_tokens = decode_output, num_decode_tokens
        elif decode_output is None and prefill_output is not None:
            # Prefill only batch.
            output, num_tokens = prefill_output, num_prefill_tokens
        else:
            # Chunked prefill batch does not work with speculative decoding in
            # FlashInfer backend, so the query length for decode should be 1.
            assert prefill_output is not None
            assert decode_output is not None
            assert decode_meta is not None
            assert decode_meta.decode_query_len == 1
            # Flatten both halves to (tokens, hidden_size) before joining.
            # The former ``decode_output.squeeze(1)`` only had an effect when
            # the head axis had size 1. In that case it dropped that axis,
            # and torch.cat then failed on mismatched ranks.
            output = torch.cat([
                prefill_output.reshape(-1, hidden_size),
                decode_output.reshape(-1, hidden_size),
            ],
                               dim=0)
        return output.view(num_tokens, hidden_size)