flashinfer.py 46.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
9

10
11
from vllm.multimodal import MultiModalPlaceholderMap

12
13
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
14
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
15
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
16

17
    from vllm.vllm_flash_attn import flash_attn_varlen_func
18
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
19
except ImportError:
20
21
22
23
24
    # Avoid turning these types into variables during type checking
    if not TYPE_CHECKING:
        BatchDecodeWithPagedKVCacheWrapper = None
        CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
        BatchPrefillWithPagedKVCacheWrapper = None
25
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
26

27
28
import torch

29
import vllm.envs as envs
30
31
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
32
                                              AttentionLayer,
33
34
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
35
                                              AttentionState, AttentionType)
36
37
38
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
39
from vllm.attention.layer import Attention
40
from vllm.attention.ops.paged_attn import PagedAttention
41
from vllm.config import VllmConfig, get_layers_from_vllm_config
42
from vllm.logger import init_logger
43
44
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                        make_tensor_with_pad)
45

46
47
logger = init_logger(__name__)

48
if TYPE_CHECKING:
49
50
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
51

52
FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD"
53

54
55
56

class FlashInferBackend(AttentionBackend):
    """Backend descriptor that wires the FlashInfer classes together.

    Every member is a static method: the class is used as a namespace that
    names the implementation / metadata / builder / state classes, describes
    the KV-cache tensor shape and stride order, and forwards block-level
    cache operations to `PagedAttention`.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        """Return the attention implementation class."""
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlashInferMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashInferState"]:
        """Return the attention state class."""
        return FlashInferState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the 5-D KV-cache shape.

        The second dimension has size 2 (presumably key and value planes --
        confirm against the cache writer).
        """
        kv_cache_shape = (num_blocks, 2, block_size, num_kv_heads, head_size)
        return kv_cache_shape

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        """Return the dimension permutation matching the configured layout.

        "NHD" keeps the dimensions in their natural order, "HND" swaps
        dimensions 2 and 3 of the shape from `get_kv_cache_shape`.
        """
        layout = FLASHINFER_KV_CACHE_LAYOUT
        assert (layout in ("NHD", "HND"))
        if layout == "NHD":
            return (0, 1, 2, 3, 4)
        return (0, 1, 3, 2, 4)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward the block swap to `PagedAttention.swap_blocks`."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward the block copy to `PagedAttention.copy_blocks`."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by this backend."""
        return [64, 128, 256]

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Translate an FP8 KV-cache dtype name into a torch dtype.

        Raises:
            ValueError: if `kv_cache_dtype` is not one of "fp8",
                "fp8_e4m3" or "fp8_e5m2".
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        if kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

122

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
@dataclass
class PerLayerParameters:
    """
    Attention hyperparameters read from a single attention layer.

    The FlashInfer backend currently only supports models in which every
    layer uses the same values for all of the fields below.
    """

    # Left window size; -1 is used when the layer has no sliding window.
    window_left: int
    # Logits soft-capping value of the layer, or None.
    logits_soft_cap: Optional[float]
    # Softmax scale of the layer.
    sm_scale: float


def get_per_layer_parameters(
        vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
    """
    Collect, for every attention layer in the config, the hyperparameters
    that are later passed to `plan`.

    Returns a dict keyed by the same keys that
    `get_layers_from_vllm_config` yields. Each layer's `impl` must be a
    `FlashInferImpl`, otherwise an assertion fails.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    collected: Dict[str, PerLayerParameters] = {}

    for layer_name, attn_layer in attn_layers.items():
        layer_impl = attn_layer.impl
        assert isinstance(layer_impl, FlashInferImpl)

        # A missing sliding window is encoded as window_left == -1.
        sliding_window = layer_impl.sliding_window
        if sliding_window is None:
            window_left = -1
        else:
            window_left = sliding_window[0]
        logits_soft_cap = layer_impl.logits_soft_cap
        sm_scale = layer_impl.scale

        collected[layer_name] = PerLayerParameters(window_left,
                                                   logits_soft_cap, sm_scale)

    return collected


def infer_global_hyperparameters(
        per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
    """
    Reduce per-layer hyperparameters to the single set shared by all layers.

    The FlashInfer backend currently only supports models whose layers all
    agree on `window_left`, `logits_soft_cap` and `sm_scale`. This function
    asserts that the mapping is non-empty and that every entry compares
    equal to the first one, then returns that first entry.
    """

    assert len(per_layer_params) > 0, "No attention layers found in the model."

    candidates = tuple(per_layer_params.values())
    reference = candidates[0]
    all_match = all(candidate == reference for candidate in candidates)
    assert all_match, (
        "FlashInfer backend currently only supports models in which all "
        "layers share the same values for the following hyperparameters: "
        "`window_left`, `logits_soft_cap`, `sm_scale`.")

    return reference


187
188
189
190
191
192
193
194
195
class FlashInferState(AttentionState):
    """Per-runner FlashInfer state: workspace buffer, wrappers and the
    temporary buffers used while capturing CUDA graphs.

    The workspace buffer and the prefill/decode wrappers are created lazily
    on first use and cached on the instance.
    """

    # Attributes that exist only while `graph_capture` is active, listed in
    # the order in which they are released.
    _GRAPH_CAPTURE_ATTRS = (
        "_graph_slot_mapping",
        "_graph_seq_lens",
        "_graph_block_tables",
        "_graph_decode_workspace_buffer",
        "_graph_indices_buffer",
        "_graph_indptr_buffer",
        "_graph_last_page_len_buffer",
        "_graph_decode_wrapper",
    )

    def __init__(self, runner):
        """Store the runner and initialise all lazily-built members."""
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

        # Global hyperparameters shared by all attention layers
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = self.runner.vllm_config
        self._kv_cache_layout = None

    def _get_workspace_buffer(self):
        """Return the uint8 workspace buffer, allocating it on first use."""
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def get_kv_cache_layout(self):
        """Return the KV-cache layout string, cached from the module-level
        `FLASHINFER_KV_CACHE_LAYOUT` on first use."""
        if self._kv_cache_layout is None:
            self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT
        return self._kv_cache_layout

    def _get_prefill_wrapper(self):
        """Return the prefill wrapper, creating it on first use."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), self.get_kv_cache_layout())
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        """Return the decode wrapper, creating it on first use.

        Tensor cores are requested when forced through
        `VLLM_FLASHINFER_FORCE_TENSOR_CORES` or when the ratio of query heads
        to KV heads (integer division) is greater than 4.
        """
        if self._decode_wrapper is None:
            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config))
            num_kv_heads = self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config)
            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                num_qo_heads // num_kv_heads > 4)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                self.get_kv_cache_layout(),
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that provides the buffers needed for CUDA graph
        capture and releases them on exit.

        The capturing flag is reset and the buffers are released in a
        `finally` clause, so an exception raised during setup or inside the
        `with` body does not leave the state stuck in capture mode or keep
        the buffers referenced by this instance.
        """
        self._is_graph_capturing = True
        try:
            self._graph_decode_wrapper = None
            self._graph_slot_mapping = torch.full((max_batch_size, ),
                                                  PAD_SLOT_ID,
                                                  dtype=torch.long,
                                                  device=self.runner.device)
            self._graph_seq_lens = torch.ones(max_batch_size,
                                              dtype=torch.int32,
                                              device=self.runner.device)
            self._graph_block_tables = torch.from_numpy(
                self.runner.graph_block_tables).to(device=self.runner.device)
            self._graph_decode_workspace_buffer = self._get_workspace_buffer()
            self._graph_indices_buffer = torch.empty(
                max_batch_size * self.runner.cache_config.num_gpu_blocks,
                dtype=torch.int32,
                device=self.runner.device)
            self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                    dtype=torch.int32,
                                                    device=self.runner.device)
            self._graph_last_page_len_buffer = torch.empty(
                max_batch_size, dtype=torch.int32, device=self.runner.device)
            yield
        finally:
            self._is_graph_capturing = False
            # Setup may have failed part-way, so only drop what exists.
            for attr_name in self._GRAPH_CAPTURE_ATTRS:
                if hasattr(self, attr_name):
                    delattr(self, attr_name)

    def graph_clone(self, batch_size: int):
        """Return a new state for one captured graph.

        The clone shares the capture workspace buffer, the most recently
        created graph decode wrapper and this state's prefill wrapper.
        Only valid while `graph_capture` is active. `batch_size` is not
        used here.
        """
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only metadata used to capture a graph of the given
        batch size.

        Creates a `CUDAGraphBatchDecodeWithPagedKVCacheWrapper` over slices
        of the capture buffers, stores it on `self._graph_decode_wrapper`,
        fills the metadata with placeholder page tables and calls
        `begin_forward()` on it before returning it. Only valid while
        `graph_capture` is active. `is_encoder_decoder_model` is not used
        here.
        """
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config))
        num_kv_heads = self.runner.model_config.get_num_kv_heads(
            self.runner.parallel_config)
        use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
            num_qo_heads // num_kv_heads > 4)
        self._graph_decode_wrapper = \
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
            self._graph_decode_workspace_buffer, _indptr_buffer,
            self._graph_indices_buffer, _last_page_len_buffer,
            self.get_kv_cache_layout(),
            use_tensor_cores)
        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        # Placeholder page tables: request i owns exactly page i and that
        # page is full (last page length == block size).
        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
                                                        self.runner.block_size,
                                                        dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        global_params = infer_global_hyperparameters(
            get_per_layer_parameters(self.vllm_config))

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None,
            **dataclasses.asdict(global_params),
        )
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        """Return the graph input buffers owned by the attention metadata;
        only the slot mapping is exposed."""
        return {
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        """No-op: nothing is copied into the graph input buffers here."""
        return

    def begin_forward(self, model_input):
        """Attach the prefill/decode wrappers to the model input's attention
        metadata and call its `begin_forward()`.

        For a decode-only batch running with CUDA graphs, the wrappers are
        taken from the state of the matching captured graph runner, which is
        looked up by virtual engine and `(batch_size, uses_inputs_embeds)`.
        Must not be called while `graph_capture` is active.
        """
        assert not self._is_graph_capturing
        state = self
        use_cuda_graph = model_input.attn_metadata.use_cuda_graph
        is_decode = model_input.attn_metadata.num_prefills == 0
        # In case of multistep chunked-prefill, there might be prefill requests
        # scheduled while CUDA graph mode is enabled. We don't run graph in that
        # case.
        if use_cuda_graph and is_decode:
            if model_input.inputs_embeds is None:
                batch_size = model_input.input_tokens.shape[0]
                state = (
                    self.runner.graph_runners[model_input.virtual_engine][(
                        batch_size, False)].attn_state)
            else:
                batch_size = model_input.inputs_embeds.shape[0]
                state = (
                    self.runner.graph_runners[model_input.virtual_engine][(
                        batch_size, True)].attn_state)

        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
        )
        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
        model_input.attn_metadata.begin_forward()


386
@dataclass
class FlashInferMetadata(AttentionMetadata):
    """Attention metadata for the FlashInfer backend.

    Holds the paged-KV description (indices / indptr / last page length),
    model geometry, dtypes and the prefill/decode wrappers, and knows how to
    call `plan` on those wrappers (`begin_forward`) and how to advance one
    decode step in place (`advance_step`).
    """

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int] = 1

    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # used for GPU in-place advance_step
    seq_lens_tensor: Optional[torch.Tensor] = None
    block_table_bound: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None
    # The data type of the query
    q_data_type: Optional[torch.dtype] = None
    # FlashInfer 0.2 encourages passing host tensors
    device: torch.device = torch.device("cpu")
    is_profile_run: bool = False

    # The FlashInfer backend currently supports only models in which all layers
    # share the same following hyperparameters:

    # The left (inclusive) window size for the attention window, when
    # set to `-1`, the window size will be set to the full length of
    # the sequence. Defaults to `-1`.
    window_left: int = -1
    # The attention logits soft capping value (used in Gemini, Grok and
    # Gemma-2, etc.), if not provided, will be set to `0`. If greater
    # than 0, the logits will be capped according to formula:
    # $$\texttt{logits\_soft\_cap} \times
    # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
    # where $x$ is the input logits.
    logits_soft_cap: Optional[float] = None
    # The scale used in softmax, if not provided, will be set to
    # `1.0 / sqrt(head_dim)`.
    sm_scale: Optional[float] = None

    def __post_init__(self):
        """Validate `head_dim` against the backend's supported head sizes.

        Raises:
            ValueError: if `head_dim` is set and is not a supported size.
        """
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # Build ONE message string; passing two positional strings to
            # ValueError would make the message render as a tuple.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim,"
                f" received {self.head_dim}.")

    def begin_forward(self):
        """Move the paged-KV tensors to `self.device` and call `plan` on the
        prefill and/or decode wrapper.

        Prefill planning runs when `num_prefill_tokens > 0` and is skipped
        for profile runs. Decode planning runs when `num_decode_tokens > 0`.
        """
        if self.num_prefill_tokens > 0:
            # Nothing to plan without page indices. Note that this returns
            # from the whole method, so the decode branch is skipped too.
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            assert self.query_start_loc is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            assert self.block_table_bound is not None
            assert self.seq_lens_tensor is not None
            # Keep only the entries that belong to prefill requests.
            self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
            batch_size = self.query_start_loc.shape[0] - 1
            assert batch_size >= 0
            # We will use flash attention for profiling to
            # determine the number of blocks. Therefore,
            # we don't need to prepare the input for flashinfer for profile run.
            if not self.is_profile_run:
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)
                self.block_table_bound = self.block_table_bound.to(self.device)
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.prefill_wrapper.plan(
                    self.query_start_loc,
                    self.paged_kv_indptr[:self.num_prefills + 1],
                    self.paged_kv_indices,
                    self.paged_kv_last_page_len[:self.num_prefills],
                    self.num_qo_heads,
                    self.num_kv_heads,
                    self.head_dim,
                    self.page_size,
                    causal=True,
                    sm_scale=self.sm_scale,
                    window_left=self.window_left,
                    logits_soft_cap=self.logits_soft_cap,
                    q_data_type=self.q_data_type,
                    kv_data_type=self.data_type)
        if self.num_decode_tokens > 0:
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            # handle model warmup path
            if self.block_table_bound is not None:
                self.block_table_bound = self.block_table_bound.to(self.device)
            if self.seq_lens_tensor is not None:
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

            assert self.decode_wrapper is not None
            # Decode entries are the ones after the first `num_prefills`.
            self.decode_wrapper.plan(
                self.paged_kv_indptr[self.num_prefills:],
                self.paged_kv_indices,
                self.paged_kv_last_page_len[self.num_prefills:],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                window_left=self.window_left,
                logits_soft_cap=self.logits_soft_cap,
                sm_scale=self.sm_scale,
                # kv-cache data type.
                kv_data_type=self.data_type,
                # query data type.
                q_data_type=self.q_data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Return the parent's zero-copy dict with the wrapper fields
        excluded.

        `prefill_wrapper` and `decode_wrapper` are added to `skip_fields`
        (the caller's set is mutated when one is passed in).
        """
        if skip_fields is None:
            skip_fields = set()
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return `self` when the batch has prefills, otherwise `None`."""
        if self.num_prefills == 0:
            return None
        return self

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return `self` when the batch has decode tokens, otherwise
        `None`."""
        if self.num_decode_tokens == 0:
            return None
        return self

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Writes the sampled token ids into `model_input.input_tokens` and then
        calls `ops.advance_step_flashinfer` to update the remaining tensors.
        """

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            # Flashinfer doesn't support speculative decoding + chunked-prefill
            # + multi-step scheduling yet.
            assert self.decode_query_len == 1
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens_tensor is not None

        assert num_seqs > 0
        assert num_queries > 0
        assert model_input.attn_metadata is not None
        assert sampled_token_ids is not None

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()

        # Update GPU tensors
        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=model_input.input_tokens,
            sampled_token_ids=model_input.input_tokens,
            input_positions=model_input.input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            block_table_bound=self.block_table_bound)

626

627
628
629
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Collects per-sequence scheduling data and builds `FlashInferMetadata`.

    `prepare()` resets the host-side accumulators; `build()` folds every
    entry of `input_builder.inter_data_list` into them and then creates the
    tensors that are handed to `FlashInferMetadata`.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Keep references to the input builder and its runner.

        Args:
            input_builder: Source of `runner`, `sliding_window` and
                `block_size`; `build()` later also reads its
                `inter_data_list` and `chunked_prefill_enabled`.
        """
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

        # Global hyperparameters shared by all attention layers. They are
        # inferred lazily on the first call to `prepare()`.
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = self.runner.vllm_config

    def prepare(self):
        """Reset the per-batch accumulators.

        On the first call this also infers the global hyperparameters
        (`window_left`, `logits_soft_cap`, `sm_scale`) and caches them on
        the builder; later calls reuse the cached values.
        """
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        self.paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        self.paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        self.paged_kv_last_page_len: List[int] = []
        # Sum of the lengths of all block tables seen so far; `build()` pads
        # `paged_kv_indices` up to this many entries.
        self.total_blocks = 0
        self.is_profile_run: bool = False

        if self.global_hyperparameters is None:
            # Infer global hyperparameters, since currently we only support
            # models in which all layers share the same values for the
            # following hyperparameters:
            # - `window_left`
            # - `logits_soft_cap`
            # - `sm_scale`
            inferred_params = infer_global_hyperparameters(
                get_per_layer_parameters(self.vllm_config))
            self.global_hyperparameters = inferred_params
            self.window_left = inferred_params.window_left
            self.logits_soft_cap = inferred_params.logits_soft_cap
            self.sm_scale = inferred_params.sm_scale

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.

        For non-profile runs the paged-KV lists (`paged_kv_indices`,
        `paged_kv_indptr`, `paged_kv_last_page_len`) are extended as well.

        Args:
            inter_data: Intermediate data of one sequence group.
            chunked_prefill_enabled: Whether chunked prefill is enabled; if
                so, prompts also take their block table from
                `inter_data.block_tables`.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            is_profile_run = is_block_tables_empty(block_tables)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
            # and paged_kv_last_page_len for profile run because we will
            # create dummy inputs.
            # NOTE: this leaves the whole method, not just the current
            # iteration.
            if is_profile_run:
                self.is_profile_run = is_profile_run
                return

            # The paged-KV lists use the full block table of the sequence,
            # not the sliding-window slice appended to self.block_tables.
            block_table = block_tables[seq_id]
            self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
        """Append one sequence to the paged-KV index lists.

        Args:
            block_table: Page indices allocated to the sequence.
            seq_len: Sequence length in tokens; determines how many leading
                entries of `block_table` are valid and how full the last
                page is.
        """
        self.total_blocks += len(block_table)

        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        full_pages, remainder = divmod(seq_len, self.block_size)
        block_table_bound = full_pages + 1 if remainder != 0 else full_pages
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)

        # A remainder of 0 means the last page is completely filled.
        last_page_len = remainder if remainder != 0 else self.block_size
        self.paged_kv_last_page_len.append(last_page_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            A `FlashInferMetadata` describing the current batch.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        decode_query_len = max(query_lens[self.num_prefills:], default=1)

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with one empty block table per padded sequence. Note that
            # `[] * n` is always `[]`, so the padding must be a list of
            # empty lists. Empty tables are skipped when filling
            # input_block_tables below.
            self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size - self.num_prefill_tokens

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            max_blocks = input_block_tables.shape[1]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    num_blocks = len(block_table)
                    if num_blocks <= max_blocks:
                        input_block_tables[i, :num_blocks] = block_table
                    else:
                        # It may be possible to have more blocks allocated due
                        # to lookahead slots of multi-step, however, they are
                        # not used anyway, so can be safely ignored.
                        input_block_tables[
                            i, :max_blocks] = block_table[:max_blocks]

            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)

            # Padded entries repeat the last indptr value (i.e. they own no
            # pages) and get a last-page length of 0.
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )

        assert device is not None
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }
        # Entry 0 stays 0; entries 1.. receive the running sums of the
        # lengths.
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        if len(self.paged_kv_indptr) > 0:
            # extend to the maximum number of blocks as returned by the
            # scheduler
            self.paged_kv_indices.extend(
                [0] * (self.total_blocks - len(self.paged_kv_indices)))
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
            block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
                                                   1,
                                                   device="cpu",
                                                   dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None
            block_table_bound_tensor = None

        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        return FlashInferMetadata(
            decode_query_len=decode_query_len,
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            block_table_bound=block_table_bound_tensor,
            seq_lens_tensor=seq_lens_tensor,
            num_qo_heads=self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config),
            num_kv_heads=self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config),
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=use_captured_graph,
            is_profile_run=self.is_profile_run,
            window_left=self.window_left,
            logits_soft_cap=self.logits_soft_cap,
            sm_scale=self.sm_scale,
        )
921
922


923
924
925
926
927
928
929
class FlashInferImpl(AttentionImpl):
    """Decoder attention that runs through the FlashInfer wrappers.

    `forward()` writes the new key/value tokens into the paged KV cache and
    then runs the prefill and/or decode wrapper carried by the
    `FlashInferMetadata`. When no KV cache is provided, prefill falls back
    to `flash_attn_varlen_func`.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Store the per-layer attention configuration.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Softmax scale; stored as a Python float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes; converted to a float32
                tensor when given.
            sliding_window: Optional sliding-window size; stored as the
                tuple `(sliding_window - 1, 0)`, or `(-1, -1)` when None.
            kv_cache_dtype: KV cache dtype string.
            blocksparse_params: Accepted but not used by this class.
            logits_soft_cap: Optional logits soft cap.
            attn_type: Attention type; only `AttentionType.DECODER` is
                accepted.
            kv_sharing_target_layer_name: Must be None.
            use_irope: Only triggers a one-time warning when set.

        Raises:
            NotImplementedError: If `kv_sharing_target_layer_name` is given
                or `attn_type` is not `AttentionType.DECODER`.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            logger.warning_once(
                "Using irope in FlashInfer is not supported yet, it will fall"
                " back to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window - 1,
                                0) if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention for a batch laid out as prefill tokens then decode
        tokens.

        Args:
            layer: Attention layer providing the KV scales (`_k_scale`,
                `_v_scale` and their `_float` variants).
            query: 2-D tensor `[num_tokens, hidden_size]`.
            key: Key tensor, viewed as `[-1, num_kv_heads, head_size]`.
            value: Value tensor, viewed as `[-1, num_kv_heads, head_size]`.
            kv_cache: Paged KV cache; `kv_cache[:, 0]` and `kv_cache[:, 1]`
                are used as the key and value caches. An empty tensor
                (`numel() == 0`) skips the cache write and makes prefill use
                `flash_attn_varlen_func`.
            attn_metadata: Metadata for the batch.
            output: Currently not written to; the result is returned as a
                new tensor.
            output_scale: Must be None.

        Returns:
            The attention output viewed as `[num_tokens, hidden_size]`.

        Raises:
            NotImplementedError: If `output_scale` is not None.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashInferImpl")

        # TODO: directly write to output tensor
        num_heads: int = self.num_heads
        head_size: int = self.head_size
        num_kv_heads: int = self.num_kv_heads
        kv_cache_dtype: str = self.kv_cache_dtype
        softmax_scale: float = self.scale
        window_size = self.sliding_window
        alibi_slopes = self.alibi_slopes
        logits_soft_cap = self.logits_soft_cap

        num_tokens, hidden_size = query.shape
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)

        if kv_cache.numel() > 0:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )
            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if kv_cache_dtype.startswith("fp8"):
                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    kv_cache_dtype)
                kv_cache = kv_cache.view(torch_dtype)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
                    f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
                    f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
        query = query.contiguous(
        )  # Flashinfer requires query to be contiguous
        # Query for decode. KV is not needed because it is already cached.
        # QKV for prefill.
        decode_query = query[num_prefill_tokens:]
        query = query[:num_prefill_tokens]

        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        window_left = window_size[0] if window_size is not None else -1

        prefill_output: Optional[torch.Tensor] = None
        decode_output: Optional[torch.Tensor] = None
        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        if prefill_meta := attn_metadata.prefill_metadata:
            # We will use flash attention for prefill
            # when kv_cache is not provided.
            # This happens when vllm runs the profiling to
            # determine the number of blocks.
            if kv_cache.numel() == 0:
                prefill_output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=softmax_scale,
                    causal=True,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                )
            else:
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None

                # The wrapper must have been set up with the same
                # hyperparameters as this layer.
                assert prefill_meta.prefill_wrapper._causal
                assert prefill_meta.prefill_wrapper._window_left == window_left
                assert prefill_meta.prefill_wrapper._logits_soft_cap == (
                    logits_soft_cap or 0.0)
                assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale

                prefill_output = prefill_meta.prefill_wrapper.run(
                    query,
                    kv_cache.permute(*stride_order),
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                )
        if decode_meta := attn_metadata.decode_metadata:
            assert decode_meta is not None
            assert decode_meta.decode_wrapper is not None

            # The wrapper must have been set up with the same
            # hyperparameters as this layer.
            assert decode_meta.decode_wrapper._window_left == window_left
            assert decode_meta.decode_wrapper._logits_soft_cap == (
                logits_soft_cap or 0.0)
            assert decode_meta.decode_wrapper._sm_scale == softmax_scale

            decode_output = decode_meta.decode_wrapper.run(
                decode_query,
                kv_cache.permute(*stride_order),
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
            )

        if prefill_output is None and decode_output is not None:
            # Decode only batch.
            output, num_tokens = decode_output, num_decode_tokens
        elif decode_output is None and prefill_output is not None:
            # Prefill only batch.
            output, num_tokens = prefill_output, num_prefill_tokens
        else:
            # Chunked prefill batch does not work with speculative decoding in
            # FlashInfer backend, so the query length for decode should be 1.
            assert prefill_output is not None
            assert decode_output is not None
            assert decode_meta is not None
            assert decode_meta.decode_query_len == 1
            decode_output = decode_output.squeeze(1)
            output = torch.cat([prefill_output, decode_output], dim=0)
        return output.view(num_tokens, hidden_size)