flashinfer.py 35 KB
Newer Older
1
from contextlib import contextmanager
2
from dataclasses import dataclass
3
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
4

5
6
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
7
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
8
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
9
10

    import vllm.attention.backends.flash_attn  # noqa
11
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
12
13
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
14
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
15
    BatchPrefillWithPagedKVCacheWrapper = None
16
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
17

18
19
20
21
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
22
23
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
24
                                              AttentionState, AttentionType)
25
26
27
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
28
from vllm.attention.ops.paged_attn import PagedAttention
29
30
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                        make_tensor_with_pad)
31
32

if TYPE_CHECKING:
33
34
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
35
36
37
38


class FlashInferBackend(AttentionBackend):
    """Backend descriptor that ties together the FlashInfer attention classes.

    Every member is a static method. The class reports which implementation,
    metadata, metadata-builder and state classes belong to this backend, and
    forwards KV-cache block operations to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "flashinfer"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        """Return the attention implementation class."""
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        """Return the class that builds the attention metadata."""
        return FlashInferMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashInferState"]:
        """Return the attention state class."""
        return FlashInferState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the 5-D shape used for the paged KV cache tensor.

        The layout is ``(num_blocks, 2, block_size, num_kv_heads,
        head_size)``.
        """
        # NOTE(review): the constant 2 presumably separates keys from
        # values -- confirm against the kernels consuming this cache.
        shape = (num_blocks, 2, block_size, num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward a block swap request to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward a block copy request to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by this backend."""
        return [64, 128, 256]

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 KV-cache dtype string to the matching torch dtype.

        Args:
            kv_cache_dtype: One of ``"fp8"``, ``"fp8_e4m3"`` or
                ``"fp8_e5m2"``.

        Returns:
            ``torch.float8_e4m3fn`` for ``"fp8"``/``"fp8_e4m3"`` and
            ``torch.float8_e5m2`` for ``"fp8_e5m2"``.

        Raises:
            ValueError: If the string is none of the values above.
        """
        if kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

96

97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class FlashInferState(AttentionState):
    """Owns the FlashInfer workspace buffer and prefill/decode wrappers.

    The workspace buffer and both wrappers are created lazily on first use.
    While ``graph_capture`` is active, additional ``_graph_*`` buffers exist
    on the instance; they back the CUDA-graph decode wrapper created by
    ``graph_capture_get_metadata_for_batch``.
    """

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

    def _get_head_counts(self) -> Tuple[int, int]:
        """Return ``(num_qo_heads, num_kv_heads)`` from the model config."""
        model_config = self.runner.model_config
        parallel_config = self.runner.parallel_config
        num_qo_heads = model_config.get_num_attention_heads(parallel_config)
        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        return num_qo_heads, num_kv_heads

    @staticmethod
    def _should_use_tensor_cores(num_qo_heads: int,
                                 num_kv_heads: int) -> bool:
        """Return True when more than 4 query heads share one KV head."""
        return num_qo_heads // num_kv_heads > 4

    def _get_workspace_buffer(self):
        """Return the uint8 workspace buffer, allocating it on first call."""
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        """Return the prefill wrapper, creating it on first call."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        """Return the decode wrapper, creating it on first call."""
        if self._decode_wrapper is None:
            num_qo_heads, num_kv_heads = self._get_head_counts()
            use_tensor_cores = self._should_use_tensor_cores(
                num_qo_heads, num_kv_heads)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that sets up and tears down graph-capture buffers.

        On entry the capture flag is raised and the ``_graph_*`` buffers are
        allocated for up to ``max_batch_size`` sequences. On exit -- also
        when the body raises -- the flag is cleared and the buffers are
        removed from the instance, so a failed capture cannot leave the
        state permanently marked as capturing.
        """
        device = self.runner.device
        self._is_graph_capturing = True
        self._graph_decode_wrapper = None
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=device)
        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
        self._graph_indices_buffer = torch.empty(
            max_batch_size * self.runner.cache_config.num_gpu_blocks,
            dtype=torch.int32,
            device=device)
        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                dtype=torch.int32,
                                                device=device)
        self._graph_last_page_len_buffer = torch.empty(max_batch_size,
                                                       dtype=torch.int32,
                                                       device=device)
        try:
            yield
        finally:
            self._is_graph_capturing = False
            del self._graph_slot_mapping
            del self._graph_seq_lens
            del self._graph_block_tables
            del self._graph_decode_workspace_buffer
            del self._graph_indices_buffer
            del self._graph_indptr_buffer
            del self._graph_last_page_len_buffer
            del self._graph_decode_wrapper

    def graph_clone(self, batch_size: int):
        """Return a new state that reuses the current graph decode wrapper.

        Only valid while ``graph_capture`` is active. ``batch_size`` is not
        used by this implementation.
        """
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(self, batch_size: int):
        """Build decode-only metadata used to capture a graph of this size.

        Creates a ``CUDAGraphBatchDecodeWithPagedKVCacheWrapper`` over the
        capture buffers, fills placeholder paging tensors (one full page per
        sequence), calls ``begin_forward`` on the metadata and returns it.
        """
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads, num_kv_heads = self._get_head_counts()
        use_tensor_cores = self._should_use_tensor_cores(
            num_qo_heads, num_kv_heads)
        self._graph_decode_wrapper = \
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
            self._graph_decode_workspace_buffer, _indptr_buffer,
            self._graph_indices_buffer, _last_page_len_buffer, "NHD",
            use_tensor_cores)

        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        # Placeholder host-side paging data: sequence i owns exactly page i
        # and that page is reported as completely filled.
        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full(
            (batch_size, ), self.runner.block_size, dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None)
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self, attn_metadata):
        """Return the metadata tensors exposed as graph input buffers."""
        return {
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
        """No-op: this state does not update any graph input buffer here."""
        return

    def begin_forward(self, model_input):
        """Attach the wrappers to the metadata and start its forward pass.

        When the metadata uses a CUDA graph, the wrappers come from the
        state stored on the graph runner selected by the virtual engine and
        the batch size (first dimension of ``model_input.input_tokens``);
        otherwise they come from this state.
        """
        assert not self._is_graph_capturing
        attn_metadata = model_input.attn_metadata
        state = self
        if attn_metadata.use_cuda_graph:
            batch_size = model_input.input_tokens.shape[0]
            state = (self.runner.graph_runners[model_input.virtual_engine]
                     [batch_size].attn_state)
        attn_metadata.prefill_wrapper = state._get_prefill_wrapper()
        attn_metadata.decode_wrapper = state._get_decode_wrapper()
        attn_metadata.begin_forward()


256
@dataclass
class FlashInferMetadata(AttentionMetadata):
    """Attention metadata carrying the paging data FlashInfer needs.

    NOTE(review): ``num_prefills``, ``num_prefill_tokens``,
    ``num_decode_tokens`` and ``slot_mapping`` are read here but not declared
    in this class -- presumably they are fields of ``AttentionMetadata``.
    """
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int

    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # used for GPU in-place advance_step
    seq_lens_tensor: Optional[torch.Tensor] = None
    block_table_bound: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None
    # The data type of the query
    q_data_type: Optional[torch.dtype] = None
    device: torch.device = torch.device("cuda")
    is_profile_run: bool = False

    def __post_init__(self):
        """Reject a ``head_dim`` that the backend does not support.

        Raises:
            ValueError: If ``head_dim`` is set and not one of
                ``FlashInferBackend.get_supported_head_sizes()``.
        """
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # Build ONE message string; passing two strings to ValueError
            # would make the error render as a tuple.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    def begin_forward(self):
        """Move paging tensors to ``self.device`` and arm the wrapper.

        With prefill tokens present the prefill wrapper is used (and nothing
        happens when ``paged_kv_indices`` is missing or this is a profile
        run); otherwise the decode wrapper is used.
        """
        if self.num_prefill_tokens > 0:
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            assert self.query_start_loc is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            assert self.block_table_bound is not None
            assert self.seq_lens_tensor is not None
            batch_size = self.query_start_loc.shape[0] - 1
            assert batch_size >= 0
            # We will use flash attention for profiling to
            # determine the number of blocks. Therefore,
            # we don't need to prepare the input for flashinfer for profile
            # run.
            if not self.is_profile_run:
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)
                self.block_table_bound = self.block_table_bound.to(self.device)
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.prefill_wrapper.end_forward()
                self.prefill_wrapper.begin_forward(
                    self.query_start_loc, self.paged_kv_indptr,
                    self.paged_kv_indices, self.paged_kv_last_page_len,
                    self.num_qo_heads, self.num_kv_heads, self.head_dim,
                    self.page_size)
        else:
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            # handle model warmup path
            if self.block_table_bound is not None:
                self.block_table_bound = self.block_table_bound.to(self.device)
            if self.seq_lens_tensor is not None:
                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

            assert self.decode_wrapper is not None
            self.decode_wrapper.end_forward()
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
                self.paged_kv_last_page_len,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                # kv-cache data type.
                data_type=self.data_type,
                # query data type.
                q_data_type=self.q_data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Return the parent's dict view without the wrapper fields.

        The caller's ``skip_fields`` set is copied rather than modified.
        """
        # Work on a copy so the caller's set is left untouched.
        skip_fields = set() if skip_fields is None else set(skip_fields)
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return ``self`` for a batch without decode tokens, else None."""
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self

        return None

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return ``self`` for a batch without prefills, else None."""
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
            return None

        return self

    def advance_step(
        self,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        sampled_token_ids: Optional[torch.Tensor],
        block_size: int,
        num_seqs: int,
        num_queries: int,
    ):
        """
        Update metadata in-place to advance one decode step.

        The sampled token ids are written into the first ``num_queries``
        entries of ``model_input.input_tokens``; the remaining tensors are
        updated by ``ops.advance_step_flashinfer``.
        """

        assert num_seqs > 0
        assert num_queries > 0
        assert model_input.attn_metadata is not None
        assert sampled_token_ids is not None

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()

        # Update GPU tensors
        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=model_input.input_tokens,
            sampled_token_ids=model_input.input_tokens,
            input_positions=model_input.input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            block_table_bound=self.block_table_bound)

449

450
451
452
453
454
455
456
457
458
459
460
461
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Accumulates per-sequence data and builds a ``FlashInferMetadata``."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        self.paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        self.paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        self.paged_kv_last_page_len: List[int] = []
        # Running sum of the block-table lengths of all added sequences.
        self.total_blocks = 0
        self.is_profile_run: bool = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    f"seq_len: {seq_len}, context_len: {context_len}, "
                    f"query_len: {query_len}")
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            is_profile_run = is_block_tables_empty(block_tables)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
            # and paged_kv_last_page_len for profile run because we will
            # create dummy inputs.
            if is_profile_run:
                self.is_profile_run = is_profile_run
                # NOTE: this leaves the whole method, so any remaining
                # sequences of this group are not processed.
                return

            block_table = block_tables[seq_id]
            self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
        """Append one sequence's pages to the paged-KV bookkeeping lists.

        Args:
            block_table: Block ids allocated to the sequence.
            seq_len: Sequence length used to decide how many of those blocks
                are valid and how full the last one is.
        """
        self.total_blocks += len(block_table)

        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        full_pages, remainder = divmod(seq_len, self.block_size)
        block_table_bound = full_pages + 1 if remainder else full_pages
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)

        # A remainder of 0 means the last page is completely filled.
        last_page_len = remainder or self.block_size
        self.paged_kv_last_page_len.append(last_page_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with one empty block table per padded sequence. (The
            # former ``[] * n`` evaluated to ``[]`` and padded nothing.)
            # Empty tables are skipped below, so the tensor is unaffected.
            self.block_tables.extend(
                [] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            max_blocks = input_block_tables.shape[1]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    num_blocks = min(len(block_table), max_blocks)
                    input_block_tables[i, :num_blocks] = (
                        block_table[:num_blocks])

            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)

            # Padded sequences own no pages: repeat the last indptr value
            # and report a last-page length of 0 for each of them.
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, f"query_lens: {query_lens}"

        assert device is not None
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)

        # Start offsets are exclusive prefix sums: element 0 stays 0 and the
        # cumulative sums are written into elements 1..n.
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        if self.paged_kv_indptr:
            # extend to the maximum number of blocks as returned by the
            # scheduler
            self.paged_kv_indices.extend(
                [0] * (self.total_blocks - len(self.paged_kv_indices)))
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
            block_table_bound_tensor = torch.zeros(
                len(self.paged_kv_indptr) - 1,
                device="cpu",
                dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None
            block_table_bound_tensor = None

        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        return FlashInferMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            block_table_bound=block_table_bound_tensor,
            seq_lens_tensor=seq_lens_tensor,
            num_qo_heads=self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config),
            num_kv_heads=self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config),
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=use_captured_graph,
            is_profile_run=self.is_profile_run)
706
707


708
709
710
711
712
713
714
class FlashInferImpl(AttentionImpl):
    """Decoder attention implementation backed by FlashInfer wrappers.

    ``forward`` writes the new key/value tokens into the paged KV cache (when
    one is provided) and then runs either the prefill or the decode wrapper
    carried by the attention metadata. When no KV cache is provided, prefill
    falls back to ``flash_attn_varlen_func``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Store the attention hyper-parameters.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale; stored as a Python float.
            num_kv_heads: Number of key/value heads. Must evenly divide
                ``num_heads``.
            alibi_slopes: Optional ALiBi slopes. When given they are converted
                to a float32 tensor. Within this class they are only handed
                to the flash-attention fallback used when no KV cache exists.
            sliding_window: Must be ``None``.
            kv_cache_dtype: KV cache dtype string (e.g. a value starting with
                ``"fp8"``).
            blocksparse_params: Accepted for interface compatibility; not
                read by this class.
            logits_soft_cap: Optional logits soft cap forwarded to the
                FlashInfer wrappers.

        Raises:
            ValueError: If ``sliding_window`` is not ``None``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")
        # Passed as ``window_size`` to the flash-attention fallback in
        # ``forward``.
        self.sliding_window = (-1, -1)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: FlashInferMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention for a batch that is entirely prefill or decode.

        Args:
            query: 2-D tensor unpacked as ``(num_tokens, hidden_size)``; it is
                viewed as ``(-1, num_heads, head_size)`` internally.
            key: Viewed as ``(-1, num_kv_heads, head_size)``.
            value: Viewed as ``(-1, num_kv_heads, head_size)``.
            kv_cache: Paged KV cache, or ``None`` (the original comments state
                this happens during the profiling run).
            attn_metadata: Metadata carrying the slot mapping and the
                prefill/decode wrappers.
            k_scale: Must be ``1.0``.
            v_scale: Must be ``1.0``.
            attn_type: Must be ``AttentionType.DECODER``.

        Returns:
            The attention output viewed as ``(num_tokens, hidden_size)``.

        Raises:
            NotImplementedError: If ``attn_type`` is not
                ``AttentionType.DECODER``.
            AssertionError: If a scale differs from ``1.0``, if the batch
                mixes prefill and decode tokens, or if the wrapper required
                by the selected path is missing.
        """
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashInfer.")
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")
        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # A batch must be pure prefill or pure decode.
        if attn_metadata.num_prefill_tokens > 0:
            assert attn_metadata.num_decode_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
        if attn_metadata.num_decode_tokens > 0:
            assert attn_metadata.num_prefill_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
        if kv_cache is not None:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )
            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if self.kv_cache_dtype.startswith("fp8"):
                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    self.kv_cache_dtype)
                kv_cache = kv_cache.view(torch_dtype)

        # Flashinfer requires query to be contiguous
        query = query.contiguous()
        if prefill_meta := attn_metadata.prefill_metadata:
            # We will use flash attention for prefill
            # when kv_cache is not provided.
            # This happens when vllm runs the profiling to
            # determine the number of blocks.
            if kv_cache is None:
                output = torch.ops.vllm.flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None
                # Pass the configured softmax scale explicitly so this path
                # agrees with the flash-attention fallback above and with
                # the decode path below, instead of relying on the wrapper's
                # built-in default.
                output = prefill_meta.prefill_wrapper.forward(
                    query,
                    kv_cache,
                    sm_scale=self.scale,
                    logits_soft_cap=self.logits_soft_cap,
                    causal=True)
        else:
            assert attn_metadata.decode_metadata is not None
            assert attn_metadata.decode_metadata.decode_wrapper is not None
            output = attn_metadata.decode_metadata.decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
                logits_soft_cap=self.logits_soft_cap,
                k_scale=k_scale,
                v_scale=v_scale)
        return output.view(num_tokens, hidden_size)