flashinfer.py 38.7 KB
Newer Older
1
from collections import defaultdict
2
from contextlib import contextmanager
3
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
5

6
7
from vllm.multimodal import MultiModalPlaceholderMap

8
9
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
10
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
11
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
12

13
    from vllm.vllm_flash_attn import flash_attn_varlen_func
14
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
15
16
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
17
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
18
    BatchPrefillWithPagedKVCacheWrapper = None
19
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
20

21
22
import torch

23
import vllm.envs as envs
24
25
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
26
27
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
28
                                              AttentionState, AttentionType)
29
30
31
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
32
from vllm.attention.ops.paged_attn import PagedAttention
33
34
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                        make_tensor_with_pad)
35
36

if TYPE_CHECKING:
37
38
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
39
40
41
42


class FlashInferBackend(AttentionBackend):
    """Attention backend entry point for the FlashInfer kernels.

    Only static hooks live here: it names the impl / metadata / builder /
    state classes used by this backend and answers KV-cache layout queries.
    """

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashInferState"]:
        return FlashInferState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the per-layer KV-cache tensor shape.

        Within a page the trailing dims are (block_size, num_kv_heads,
        head_size), which is the "NHD" layout string handed to the FlashInfer
        wrappers. The size-2 axis presumably separates keys from values —
        confirm against the cache-writing code in FlashInferImpl.forward.
        """
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy cache blocks between two caches; delegates to PagedAttention."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within each layer's cache; delegates to PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Head sizes accepted by FlashInferMetadata's head_dim validation."""
        return [64, 128, 256]

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Map a vLLM fp8 kv-cache dtype string to the torch dtype.

        Raises:
            ValueError: if ``kv_cache_dtype`` is not a recognized fp8 name.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

class FlashInferState(AttentionState):
    """Per-runner FlashInfer state: workspace buffer and planning wrappers.

    The workspace buffer and the prefill/decode wrappers are created lazily.
    While decode CUDA graphs are being captured (``graph_capture``), a set of
    fixed-size buffers is allocated and each captured batch size receives its
    own cloned state via ``graph_clone``.
    """

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

    def _get_workspace_buffer(self):
        """Return the byte workspace, allocating it on the runner's device."""
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        """Return the (lazily created) prefill wrapper over the workspace."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
        return self._prefill_wrapper

    def _get_num_heads(self) -> Tuple[int, int]:
        """Return ``(num_qo_heads, num_kv_heads)`` for this worker."""
        model_config = self.runner.model_config
        parallel_config = self.runner.parallel_config
        return (model_config.get_num_attention_heads(parallel_config),
                model_config.get_num_kv_heads(parallel_config))

    @staticmethod
    def _use_tensor_cores(num_qo_heads: int, num_kv_heads: int):
        """Decide whether the decode wrapper should use tensor cores.

        True when forced through VLLM_FLASHINFER_FORCE_TENSOR_CORES, or when
        more than 4 query heads share each KV head.
        """
        return envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
            num_qo_heads // num_kv_heads > 4)

    def _get_kv_cache_dtype(self) -> torch.dtype:
        """Resolve the runner's kv-cache dtype string to a torch dtype."""
        if self.runner.kv_cache_dtype.startswith("fp8"):
            return FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        return get_kv_cache_torch_dtype(self.runner.kv_cache_dtype,
                                        self.runner.model_config.dtype)

    def _get_decode_wrapper(self):
        """Return the (lazily created) decode wrapper over the workspace."""
        if self._decode_wrapper is None:
            num_qo_heads, num_kv_heads = self._get_num_heads()
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                use_tensor_cores=self._use_tensor_cores(
                    num_qo_heads, num_kv_heads))
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Hold the fixed-size buffers used while capturing decode graphs.

        Buffers are sized for ``max_batch_size`` so each captured batch size
        can take a prefix slice of them. On exit, including when capture
        raises, the capturing flag is cleared and the buffers are released.
        """
        self._is_graph_capturing = True
        try:
            self._graph_decode_wrapper = None
            self._graph_slot_mapping = torch.full((max_batch_size, ),
                                                  PAD_SLOT_ID,
                                                  dtype=torch.long,
                                                  device=self.runner.device)
            self._graph_seq_lens = torch.ones(max_batch_size,
                                              dtype=torch.int32,
                                              device=self.runner.device)
            self._graph_block_tables = torch.from_numpy(
                self.runner.graph_block_tables).to(device=self.runner.device)
            self._graph_decode_workspace_buffer = self._get_workspace_buffer()
            self._graph_indices_buffer = torch.empty(
                max_batch_size * self.runner.cache_config.num_gpu_blocks,
                dtype=torch.int32,
                device=self.runner.device)
            self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                    dtype=torch.int32,
                                                    device=self.runner.device)
            self._graph_last_page_len_buffer = torch.empty(
                max_batch_size, dtype=torch.int32, device=self.runner.device)
            yield
        finally:
            self._is_graph_capturing = False
            # hasattr guards the case where an allocation above failed before
            # every capture-only attribute was set.
            for attr in ("_graph_slot_mapping", "_graph_seq_lens",
                         "_graph_block_tables",
                         "_graph_decode_workspace_buffer",
                         "_graph_indices_buffer", "_graph_indptr_buffer",
                         "_graph_last_page_len_buffer",
                         "_graph_decode_wrapper"):
                if hasattr(self, attr):
                    delattr(self, attr)

    def graph_clone(self, batch_size: int):
        """Create a state bound to the wrapper of the graph being captured."""
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only metadata for capturing a graph of ``batch_size``.

        Also creates the CUDA-graph decode wrapper over prefix views of the
        capture buffers and stores it in ``self._graph_decode_wrapper``.
        """
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads, num_kv_heads = self._get_num_heads()
        use_tensor_cores = self._use_tensor_cores(num_qo_heads, num_kv_heads)
        self._graph_decode_wrapper = \
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
            self._graph_decode_workspace_buffer, _indptr_buffer,
            self._graph_indices_buffer, _last_page_len_buffer, "NHD",
            use_tensor_cores)
        kv_cache_dtype = self._get_kv_cache_dtype()

        # Dummy decode batch: request i owns exactly one page (page i) and
        # that page is full (last page length == block_size).
        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
                                                        self.runner.block_size,
                                                        dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None)
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        """Return the metadata tensors exposed as graph input buffers."""
        return {
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        # Intentionally a no-op for this backend.
        return

    def begin_forward(self, model_input):
        """Attach wrappers to ``model_input.attn_metadata`` and plan the step.

        For CUDA-graph batches the wrappers come from the graph runner's
        cloned state so they are the ones bound to that graph's buffers.
        """
        assert not self._is_graph_capturing
        state = self
        if model_input.attn_metadata.use_cuda_graph:
            batch_size = model_input.input_tokens.shape[0]
            state = (self.runner.graph_runners[model_input.virtual_engine]
                     [batch_size].attn_state)
        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
        )
        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
        model_input.attn_metadata.begin_forward()

@dataclass
class FlashInferMetadata(AttentionMetadata):
    """Attention metadata consumed by the FlashInfer prefill/decode wrappers.

    Prefill requests come first in the batch, followed by decode requests;
    ``begin_forward`` slices the paged-KV tensors accordingly.
    """
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int] = 1

    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # used for GPU in-place advance_step
    seq_lens_tensor: Optional[torch.Tensor] = None
    block_table_bound: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None
    # The data type of the query
    q_data_type: Optional[torch.dtype] = None
    device: torch.device = torch.device("cuda")
    is_profile_run: bool = False

    def __post_init__(self):
        """Reject head sizes the FlashInfer kernels do not support.

        Raises:
            ValueError: if ``head_dim`` is set and not in
                ``FlashInferBackend.get_supported_head_sizes()``.
        """
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if (self.head_dim is not None
                and self.head_dim not in supported_head_sizes):
            # Implicit string concatenation (no comma) so ValueError gets a
            # single readable message rather than a tuple of fragments.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    def begin_forward(self):
        """Move paged-KV tensors to ``device`` and plan the wrappers.

        The prefill wrapper is planned on the first ``num_prefills`` requests
        and the decode wrapper on the remaining ones.
        """
        if self.num_prefill_tokens > 0:
            # NOTE(review): this early return also skips the decode branch
            # below for mixed batches without paged-KV indices.
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            assert self.query_start_loc is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            assert self.block_table_bound is not None
            assert self.seq_lens_tensor is not None
            self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
            batch_size = self.query_start_loc.shape[0] - 1
            assert batch_size >= 0
            # We will use flash attention for profiling to
            # determine the number of blocks. Therefore,
            # we don't need to prepare the input for flashinfer for profile run.
            if not self.is_profile_run:
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)
                self.block_table_bound = self.block_table_bound.to(self.device)
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.prefill_wrapper.end_forward()
                self.prefill_wrapper.begin_forward(
                    self.query_start_loc,
                    self.paged_kv_indptr[:self.num_prefills + 1],
                    self.paged_kv_indices,
                    self.paged_kv_last_page_len[:self.num_prefills],
                    self.num_qo_heads, self.num_kv_heads, self.head_dim,
                    self.page_size)
        if self.num_decode_tokens > 0:
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            # handle model warmup path
            if self.block_table_bound is not None:
                self.block_table_bound = self.block_table_bound.to(self.device)
            if self.seq_lens_tensor is not None:
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

            assert self.decode_wrapper is not None
            self.decode_wrapper.end_forward()
            # Decode requests follow the prefills, hence the num_prefills
            # offset into the per-request tensors.
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr[self.num_prefills:],
                self.paged_kv_indices,
                self.paged_kv_last_page_len[self.num_prefills:],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                # kv-cache data type.
                data_type=self.data_type,
                # query data type.
                q_data_type=self.q_data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Return the fields as a dict, always omitting the wrapper objects.

        The caller's ``skip_fields`` set is copied, not modified.
        """
        skip_fields = set() if skip_fields is None else set(skip_fields)
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        """``self`` if the batch contains prefill requests, else None."""
        if self.num_prefills == 0:
            return None
        return self

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        """``self`` if the batch contains decode tokens, else None."""
        if self.num_decode_tokens == 0:
            return None
        return self

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """

        assert not turn_prefills_into_decodes, \
            ("Chunked prefill is not supported with flashinfer yet. "
             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
             "specific parameter.")

        assert num_seqs > 0
        assert num_queries > 0
        assert model_input.attn_metadata is not None
        assert sampled_token_ids is not None

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()

        # Update GPU tensors
        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=model_input.input_tokens,
            sampled_token_ids=model_input.input_tokens,
            input_positions=model_input.input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            block_table_bound=self.block_table_bound)

class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Accumulates per-sequence data and builds a FlashInferMetadata."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        self.paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        self.paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        self.paged_kv_last_page_len: List[int] = []
        self.total_blocks = 0
        self.is_profile_run: bool = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            is_profile_run = is_block_tables_empty(block_tables)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
            # and paged_kv_last_page_len for profile run because we will
            # create dummy inputs. `continue` (not `return`) so the remaining
            # sequences of this group still get their counts and slots.
            if is_profile_run:
                self.is_profile_run = is_profile_run
                continue

            self._update_paged_kv_tensors(block_tables[seq_id], seq_len)

    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
        """Append one request's pages to the paged-KV index lists.

        Only the ceil(seq_len / block_size) blocks that hold tokens go into
        ``paged_kv_indices``; ``total_blocks`` counts every allocated block so
        ``build`` can pad ``paged_kv_indices`` to that size.
        """
        self.total_blocks += len(block_table)
        # Number of valid blocks, e.g. with block_size 16: seq_len 16 -> 1,
        # seq_len 15 -> 1, seq_len 17 -> 2.
        block_table_bound = -(-seq_len // self.block_size)
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)

        # A last page that is exactly full has length block_size, not 0.
        last_page_len = seq_len % self.block_size
        if last_page_len == 0:
            last_page_len = self.block_size
        self.paged_kv_last_page_len.append(last_page_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        decode_query_len = max(query_lens[self.num_prefills:], default=1)

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padding request; empty tables are
            # skipped when filling input_block_tables below.
            self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size - self.num_prefill_tokens

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            max_blocks = input_block_tables.shape[1]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    num_blocks = len(block_table)
                    if num_blocks <= max_blocks:
                        input_block_tables[i, :num_blocks] = block_table
                    else:
                        # It may be possible to have more blocks allocated due
                        # to lookahead slots of multi-step, however, they are
                        # not used anyway, so can be safely ignored.
                        input_block_tables[
                            i, :max_blocks] = block_table[:max_blocks]

            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)

            # Padding requests own no pages: repeat the last indptr value.
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )

        assert device is not None
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }
        # Exclusive prefix sums: element 0 stays 0.
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        if len(self.paged_kv_indptr) > 0:
            # extend to the maximum number of blocks as returned by the
            # scheduler
            self.paged_kv_indices.extend(
                [0] * (self.total_blocks - len(self.paged_kv_indices)))
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
            # Zero-initialized, one entry per request; presumably filled on
            # device by ops.advance_step_flashinfer — TODO confirm.
            block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
                                                   1,
                                                   device="cpu",
                                                   dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None
            block_table_bound_tensor = None

        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        return FlashInferMetadata(
            decode_query_len=decode_query_len,
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            block_table_bound=block_table_bound_tensor,
            seq_lens_tensor=seq_lens_tensor,
            num_qo_heads=self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config),
            num_kv_heads=self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config),
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=use_captured_graph,
            is_profile_run=self.is_profile_run)

class FlashInferImpl(AttentionImpl):
    """Attention implementation backed by FlashInfer paged-KV kernels.

    Decode tokens always go through the FlashInfer decode wrapper held in the
    batch metadata. Prefill tokens go through the FlashInfer prefill wrapper,
    except when the KV cache is empty (the memory-profiling run), in which
    case ``flash_attn_varlen_func`` runs directly on the raw K/V tensors.
    Only decoder self-attention is supported.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Store the per-layer attention configuration.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale forwarded to every attention kernel path.
            num_kv_heads: Number of key/value heads; must divide
                ``num_heads``.
            alibi_slopes: Optional per-head ALiBi slopes, stored as a
                float32 tensor.
            sliding_window: Optional sliding-window length; ``None``
                disables windowing.
            kv_cache_dtype: KV-cache dtype string; values starting with
                ``"fp8"`` make ``forward`` reinterpret the cache as an fp8
                dtype before calling FlashInfer.
            blocksparse_params: Accepted for interface compatibility; unused.
            logits_soft_cap: Optional logits soft cap forwarded to the
                FlashInfer kernels.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        # NOTE(review): the slopes are only forwarded to the profiling
        # flash-attn path in ``forward``, not to the FlashInfer wrappers —
        # confirm this is intended.
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # Window stored as a (left, right) tuple; (-1, -1) disables it.
        # ``forward`` passes element 0 to FlashInfer as ``window_left``.
        self.sliding_window = ((sliding_window - 1,
                                0) if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention over a batch of prefill and/or decode tokens.

        Rows are expected in prefill-first order: the first
        ``attn_metadata.num_prefill_tokens`` rows of ``query``/``key``/
        ``value`` belong to prefill sequences; the remaining rows are decode
        tokens.

        Args:
            query: 2-D tensor ``(num_tokens, num_heads * head_size)``.
            key: Tensor viewable as ``(num_tokens, num_kv_heads, head_size)``.
            value: Tensor viewable as
                ``(num_tokens, num_kv_heads, head_size)``.
            kv_cache: Paged KV cache whose ``[:, 0]`` / ``[:, 1]`` slices
                receive the new keys / values. An empty tensor signals the
                profiling run, where no cache is written.
            attn_metadata: Batch metadata providing the slot mapping, token
                counts and the prefill/decode FlashInfer wrappers.
            k_scale: Key scale forwarded to the cache write and kernels.
            v_scale: Value scale forwarded to the cache write and kernels.
            attn_type: Must be ``AttentionType.DECODER``.

        Returns:
            Attention output of shape ``(num_tokens, hidden_size)``.

        Raises:
            NotImplementedError: If ``attn_type`` is not decoder attention.
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

        num_heads: int = self.num_heads
        head_size: int = self.head_size
        num_kv_heads: int = self.num_kv_heads
        kv_cache_dtype: str = self.kv_cache_dtype
        softmax_scale: float = self.scale
        window_size = self.sliding_window
        alibi_slopes = self.alibi_slopes
        logits_soft_cap = self.logits_soft_cap

        num_tokens, hidden_size = query.shape
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)

        if kv_cache.numel() > 0:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype,
                k_scale,
                v_scale,
            )
            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if kv_cache_dtype.startswith("fp8"):
                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    kv_cache_dtype)
                kv_cache = kv_cache.view(torch_dtype)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
                    f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
                    f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
        # FlashInfer requires the query to be contiguous.
        query = query.contiguous()
        # Query for decode. KV is not needed because it is already cached.
        # QKV for prefill.
        decode_query = query[num_prefill_tokens:]
        query = query[:num_prefill_tokens]

        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        window_left = window_size[0] if window_size is not None else -1

        prefill_output: Optional[torch.Tensor] = None
        decode_output: Optional[torch.Tensor] = None
        if prefill_meta := attn_metadata.prefill_metadata:
            # We will use flash attention for prefill
            # when kv_cache is not provided.
            # This happens when vllm runs the profiling to
            # determine the number of blocks.
            if kv_cache.numel() == 0:
                prefill_output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=softmax_scale,
                    causal=True,
                    window_size=window_size,
                    alibi_slopes=alibi_slopes,
                )
            else:
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None
                # Pass the configured scale explicitly, as the decode and
                # profiling paths do; otherwise FlashInfer falls back to
                # 1/sqrt(head_dim), which is wrong for custom-scale models.
                prefill_output = prefill_meta.prefill_wrapper.forward(
                    query,
                    kv_cache,
                    logits_soft_cap=logits_soft_cap,
                    causal=True,
                    sm_scale=softmax_scale,
                    k_scale=k_scale,
                    v_scale=v_scale,
                    window_left=window_left)
        if decode_meta := attn_metadata.decode_metadata:
            assert decode_meta is not None
            assert decode_meta.decode_wrapper is not None
            decode_output = decode_meta.decode_wrapper.forward(
                decode_query,
                kv_cache,
                sm_scale=softmax_scale,
                logits_soft_cap=logits_soft_cap,
                k_scale=k_scale,
                v_scale=v_scale,
                window_left=window_left)

        if prefill_output is None and decode_output is not None:
            # Decode only batch.
            output, num_tokens = decode_output, num_decode_tokens
        elif decode_output is None and prefill_output is not None:
            # Prefill only batch.
            output, num_tokens = prefill_output, num_prefill_tokens
        else:
            # Chunked prefill batch does not work with speculative decoding in
            # FlashInfer backend, so the query length for decode should be 1.
            assert prefill_output is not None
            assert decode_output is not None
            assert decode_meta is not None
            assert decode_meta.decode_query_len == 1
            # Flatten each part to (tokens, hidden_size) before concatenating.
            # squeeze(1) on the decode output (presumably
            # (tokens, num_heads, head_size)) would drop the head axis when
            # num_heads == 1 and make the concatenation fail.
            output = torch.cat([
                prefill_output.reshape(num_prefill_tokens, hidden_size),
                decode_output.reshape(num_decode_tokens, hidden_size),
            ],
                               dim=0)
        return output.view(num_tokens, hidden_size)