from __future__ import annotations

"""
Support different attention backends.
Now there are two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""

import os
from enum import Enum, auto
from typing import TYPE_CHECKING, List

import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
    get_bool_env_var,
    is_flashinfer_available,
    should_use_tensor_core,
)

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner

if is_flashinfer_available():
    from flashinfer import (
        BatchDecodeWithPagedKVCacheWrapper,
        BatchPrefillWithPagedKVCacheWrapper,
        BatchPrefillWithRaggedKVCacheWrapper,
    )
    from flashinfer.cascade import merge_state


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


class FlashInferAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__()

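        # Whether to use flashinfer's tensor-core decode kernels; see
        # should_use_tensor_core for the exact heuristic (roughly, it pays off
        # for grouped-query attention and low-precision KV caches).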
        self.decode_use_tensor_cores = should_use_tensor_core(
            kv_cache_dtype=model_runner.kv_cache_dtype,
            num_attention_heads=model_runner.model_config.num_attention_heads
            // model_runner.tp_size,
            num_kv_heads=model_runner.model_config.get_num_kv_heads(
                model_runner.tp_size
            ),
        )

        self.max_context_len = model_runner.model_config.context_len

        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        if model_runner.sliding_window_size is not None:
            self.num_wrappers = 2
            self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
        elif model_runner.model_config.is_encoder_decoder:
            self.num_wrappers = 2
            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
        else:
            self.num_wrappers = 1
            self.dispatch_reason = None

        # Allocate buffers
        self.workspace_buffer = torch.empty(
            global_config.flashinfer_workspace_size,
            dtype=torch.uint8,
            device=model_runner.device,
        )
        max_bs = model_runner.req_to_token_pool.size
        self.kv_indptr = [
            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
            for _ in range(self.num_wrappers)
        ]
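        # The KV cache uses a page size of 1 (see the begin_forward calls
        # below), so the last page of every request always holds exactly one
        # token and this buffer can stay all ones.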
        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )
        self.qo_indptr = [
            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
            for _ in range(self.num_wrappers)
        ]

        # Create wrappers
        # NOTE: we do not use ragged attention when there are multiple wrappers
        self.prefill_wrapper_ragged = (
            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
            if self.num_wrappers == 1
            else None
        )

        # Two wrappers: one for sliding window attention and one for full attention.
        # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs.
        self.prefill_wrappers_paged = []
        self.decode_wrappers = []
        for _ in range(self.num_wrappers):
            self.prefill_wrappers_paged.append(
                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
            )
            self.decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer,
                    "NHD",
                    use_tensor_cores=self.decode_use_tensor_cores,
                )
            )

        # Create indices updater
        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
            model_runner, self
        )

        # Other metadata
        self.forward_metadata = None
        self.cuda_graph_metadata = {}

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode():
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                decode_wrappers=None,
                encoder_lens=forward_batch.encoder_lens,
            )
            self.forward_metadata = (self.decode_wrappers,)
        else:
            prefix_lens = forward_batch.extend_prefix_lens

            # Some heuristics to check whether to use ragged forward
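            # Ragged (non-paged) prefill over the new tokens only pays off for
            # large extend batches, and is only used when a single wrapper is
            # in play (see the NOTE above prefill_wrapper_ragged in __init__).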
            if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
                use_ragged = True
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
            else:
                use_ragged = False
                extend_no_prefix = False

            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                use_ragged=use_ragged,
                encoder_lens=forward_batch.encoder_lens,
            )

            self.forward_metadata = (use_ragged, extend_no_prefix)

    def init_cuda_graph_state(self, max_bs: int):
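        # The capture-time index buffer must cover the worst case: every
        # request in the batch reaching the full context length.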
        cuda_graph_kv_indices = torch.zeros(
            (max_bs * self.max_context_len,),
            dtype=torch.int32,
            device="cuda",
        )
        self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
        ]

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: torch.Tensor = None,
    ):
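        # Recreate the decode wrappers with use_cuda_graph=True so they read
        # their indptr/indices/last_page_len from the pre-allocated buffers
        # that the captured graph will reuse on every replay.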
        decode_wrappers = []
        for i in range(self.num_wrappers):
            decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer,
                    "NHD",
                    use_cuda_graph=True,
                    use_tensor_cores=self.decode_use_tensor_cores,
                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
                )
            )

        seq_lens_sum = seq_lens.sum().item()
        self.indices_updater_decode.update(
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            decode_wrappers=decode_wrappers,
            encoder_lens=encoder_lens,
        )
        self.cuda_graph_metadata[bs] = decode_wrappers
        self.forward_metadata = (decode_wrappers,)

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: torch.Tensor = None,
    ):
        self.indices_updater_decode.update(
            req_pool_indices[:bs],
            seq_lens[:bs],
            seq_lens_sum,
            decode_wrappers=self.cuda_graph_metadata[bs],
            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
        )

    def get_cuda_graph_seq_len_fill_value(self):
        return 0

    def forward_extend(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        prefill_wrapper_paged = self.prefill_wrappers_paged[
            self._get_wrapper_idx(layer)
        ]

        use_ragged, extend_no_prefix = self.forward_metadata
        cache_loc = (
            forward_batch.out_cache_loc
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        if not use_ragged:
            if k is not None:
                assert v is not None
                if save_kv_cache:
                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

            o = prefill_wrapper_paged.forward(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=not layer.is_cross_attention,
                sm_scale=layer.scaling,
                window_left=layer.sliding_window_size,
                logits_soft_cap=layer.logit_cap,
            )
        else:
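            # Ragged path: attend over the new (extend) tokens first; if there
            # is a cached prefix, the paged wrapper handles it below and the
            # two partial results are combined with merge_state via their LSE.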
            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=layer.logit_cap,
            )

            if extend_no_prefix:
                o = o1
            else:
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                    causal=False,
                    sm_scale=layer.scaling,
                    logits_soft_cap=layer.logit_cap,
                )

                o, _ = merge_state(o1, s1, o2, s2)

            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
        cache_loc = (
            forward_batch.out_cache_loc
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

        o = decode_wrapper.forward(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            sm_scale=layer.scaling,
            logits_soft_cap=layer.logit_cap,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def _get_wrapper_idx(self, layer: RadixAttention):
        if self.num_wrappers == 1:
            return 0

        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
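            # The boolean doubles as the wrapper index: layers without a
            # sliding window (sliding_window_size == -1) use the full-attention
            # wrapper at index 1, windowed layers use index 0.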
            return layer.sliding_window_size == -1
        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
            return layer.is_cross_attention

        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")


class FlashInferIndicesUpdaterDecode:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            model_runner.tp_size
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size

        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.decode_wrappers = attn_backend.decode_wrappers

        # Dispatch
        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            self.update = self.update_sliding_window
        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
            self.update = self.update_cross_attention
        else:
            assert self.attn_backend.num_wrappers == 1
            self.update = self.update_single_wrapper

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrappers: List,
        encoder_lens: torch.Tensor,
    ):
        # Keep the signature for type checking. It will be reassigned at runtime.
        raise NotImplementedError()

    def update_single_wrapper(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrappers: List,
        encoder_lens: torch.Tensor,
    ):
        decode_wrappers = decode_wrappers or self.decode_wrappers
        self.call_begin_forward(
            decode_wrappers[0],
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            self.kv_indptr[0],
            None,
        )

    def update_sliding_window(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrappers: List,
        encoder_lens: torch.Tensor,
    ):
        decode_wrappers = decode_wrappers or self.decode_wrappers

        for wrapper_id in range(2):
            if wrapper_id == 0:
                # Sliding window attention
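                # Windowed layers only see the last sliding_window_size + 1
                # tokens, so clamp the paged KV length and start reading the
                # KV cache at kv_start_idx.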
                paged_kernel_lens_tmp = torch.minimum(  # TODO: replace this with clamp
                    seq_lens,
                    torch.tensor(self.sliding_window_size + 1),
                )
                paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
                kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
            else:
                # Full attention
                paged_kernel_lens_tmp = seq_lens
                paged_kernel_lens_sum_tmp = seq_lens_sum
                kv_start_idx_tmp = None

            self.call_begin_forward(
                decode_wrappers[wrapper_id],
                req_pool_indices,
                paged_kernel_lens_tmp,
                paged_kernel_lens_sum_tmp,
                self.kv_indptr[wrapper_id],
                kv_start_idx_tmp,
            )

    def update_cross_attention(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrappers: List,
        encoder_lens: torch.Tensor,
    ):
        decode_wrappers = decode_wrappers or self.decode_wrappers

        for wrapper_id in range(2):
            if wrapper_id == 0:
                # Normal attention
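                # Decoder self-attention: skip the encoder tokens stored at the
                # front of each request's slot range, hence kv_start_idx = encoder_lens.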
                paged_kernel_lens = seq_lens
                kv_start_idx = encoder_lens
            else:
                # Cross attention
                paged_kernel_lens = encoder_lens
                kv_start_idx = torch.zeros_like(encoder_lens)
                seq_lens_sum = encoder_lens.sum().item()

            self.call_begin_forward(
                decode_wrappers[wrapper_id],
                req_pool_indices,
                paged_kernel_lens,
                seq_lens_sum,
                self.kv_indptr[wrapper_id],
                kv_start_idx,
            )

    def call_begin_forward(
        self,
        wrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        kv_indptr: torch.Tensor,
        kv_start_idx: torch.Tensor,
    ):
        bs = len(req_pool_indices)
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
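        # kv_indptr[i] is request i's offset into the flat kv_indices array,
        # which the Triton kernel below fills with the physical KV-cache slots
        # taken from req_to_token.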
        kv_indices = torch.empty(
            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
        )

        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            kv_start_idx,
            kv_indices,
            self.req_to_token.shape[1],
        )

        wrapper.end_forward()
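        # The trailing positional argument fixes the page size at 1, matching
        # the all-ones kv_last_page_len buffer.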
        wrapper.begin_forward(
            kv_indptr,
            kv_indices,
            self.kv_last_page_len[:bs],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            data_type=self.data_type,
            q_data_type=self.q_data_type,
        )


class FlashInferIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            model_runner.tp_size
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size

        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
        self.wrappers_paged = attn_backend.prefill_wrappers_paged

        # Dispatch
        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            self.update = self.update_sliding_window
        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
            self.update = self.update_cross_attention
        else:
            assert self.attn_backend.num_wrappers == 1
            self.update = self.update_single_wrapper

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        use_ragged: bool,
        encoder_lens: torch.Tensor,
    ):
        # Keep the signature for type checking. It will be reassigned at runtime.
        raise NotImplementedError()

    def update_single_wrapper(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        use_ragged: bool,
        encoder_lens: torch.Tensor,
    ):
        if use_ragged:
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
            paged_kernel_lens = seq_lens
            paged_kernel_lens_sum = seq_lens_sum

        self.call_begin_forward(
            self.wrapper_ragged,
            self.wrappers_paged[0],
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            None,
            self.kv_indptr[0],
            self.qo_indptr[0],
            use_ragged,
        )

    def update_sliding_window(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        use_ragged: bool,
        encoder_lens: torch.Tensor,
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
                # Sliding window attention uses the paged wrapper only
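                # The paged KV range must cover the sliding window plus all new
                # (extend) tokens, since each new token attends to its own
                # trailing window.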
                paged_kernel_lens = torch.minimum(
                    seq_lens,
                    torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
                )
                paged_kernel_lens_sum = paged_kernel_lens.sum().item()
            else:
                # full attention
                paged_kernel_lens = seq_lens
                paged_kernel_lens_sum = seq_lens_sum

            kv_start_idx = seq_lens - paged_kernel_lens

            self.call_begin_forward(
                self.wrapper_ragged,
                self.wrappers_paged[wrapper_id],
                req_pool_indices,
                paged_kernel_lens,
                paged_kernel_lens_sum,
                seq_lens,
                prefix_lens,
                kv_start_idx,
                self.kv_indptr[wrapper_id],
                self.qo_indptr[wrapper_id],
                use_ragged,
            )

    def update_cross_attention(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        use_ragged: bool,
        encoder_lens: torch.Tensor,
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
                # normal attention
                paged_kernel_lens = seq_lens
                kv_start_idx = encoder_lens
                paged_kernel_lens_sum = seq_lens_sum
            else:
                # cross attention
                paged_kernel_lens = encoder_lens
                kv_start_idx = torch.zeros_like(encoder_lens)
                paged_kernel_lens_sum = paged_kernel_lens.sum().item()

            self.call_begin_forward(
                self.wrapper_ragged,
                self.wrappers_paged[wrapper_id],
                req_pool_indices,
                paged_kernel_lens,
                paged_kernel_lens_sum,
                seq_lens,
                prefix_lens,
                kv_start_idx,
                self.kv_indptr[wrapper_id],
                self.qo_indptr[wrapper_id],
                use_ragged,
            )

    def call_begin_forward(
        self,
        wrapper_ragged,
        wrapper_paged,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: torch.Tensor,
        kv_start_idx: torch.Tensor,
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
    ):
        bs = len(req_pool_indices)
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        kv_indices = torch.empty(
            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
        )
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            kv_start_idx,
            kv_indices,
            self.req_to_token.shape[1],
        )

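        # qo_indptr marks each request's range of new (extend) tokens; only
        # these tokens produce queries in the extend pass.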
        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
        qo_indptr = qo_indptr[: bs + 1]

        # extend part
        if use_ragged:
            wrapper_ragged.end_forward()
            wrapper_ragged.begin_forward(
                qo_indptr,
                qo_indptr,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
            )

        # cached part
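        # When use_ragged is set, the paged wrapper only covers the cached
        # prefix; otherwise it covers the whole (prefix + extend) KV range.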
        wrapper_paged.end_forward()
        wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            self.kv_last_page_len[:bs],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
        )


@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
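    # One program per request: copy that request's KV-cache slot indices from
    # req_to_token[req_pool_index, kv_start:kv_end] into the flat kv_indices
    # array at offset kv_indptr[pid], in chunks of BLOCK_SIZE.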
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < kv_end - kv_start
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + offset,
            mask=mask,
        )
        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)