# triton_backend.py
from __future__ import annotations

3
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Optional, Union
5
6

import torch
7
import triton
8
import triton.language as tl
9

10
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
11
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
12
from sglang.srt.layers.dp_attention import get_attention_tp_size
13
from sglang.srt.layers.radix_attention import AttentionType
14
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
15
from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
16
17

if TYPE_CHECKING:
18
    from sglang.srt.layers.radix_attention import RadixAttention
19
    from sglang.srt.model_executor.model_runner import ModelRunner
20
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
21
22


23
24
25
26
27
28
29
30
31
32
33
@dataclass
class ForwardMetadata:
    """Per-batch attention metadata produced by the backend's init methods.

    Instances are built with all twelve fields in the order declared here, so
    the field order must not change. A field is ``None`` whenever the current
    forward mode does not need it.
    """

    # Intermediate buffers handed to the decode kernel; None in extend paths.
    attn_logits: Optional[torch.Tensor]
    attn_lse: Optional[torch.Tensor]
    # Longest query span passed to the extend kernel; None in the decode path.
    max_extend_len: Optional[int]
    # Number of KV splits per token for the decode kernel; None in extend paths.
    num_kv_splits: Optional[torch.Tensor]
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: Optional[torch.Tensor]
    custom_mask: Optional[torch.Tensor]
    mask_indptr: Optional[torch.Tensor]

    # Sliding window
    window_kv_indptr: Optional[torch.Tensor]
    window_kv_indices: Optional[torch.Tensor]
    window_num_kv_splits: Optional[torch.Tensor]
38
39


40
class TritonAttnBackend(AttentionBackend):
41
42
43
44
45
46
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        """Set up kernels, index buffers and cached model/server settings.

        Args:
            model_runner: Source of the request/token pools, model config,
                server args and target device.
            skip_prefill: When True, the ``qo_indptr`` and ``mask_indptr``
                buffers used by the extend paths are not allocated.
            kv_indptr_buf: Optional pre-allocated buffer used as
                ``self.kv_indptr`` instead of allocating a new one.

        Raises:
            AssertionError: If a sliding window is configured on an
                encoder-decoder model.
        """
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        # Keep torch.compile from tracing into the Triton kernel wrappers.
        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size

        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"
        self.sliding_window_size = model_runner.sliding_window_size

        # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        # If sliding window is enabled, we might need two sets of buffers
        # because of interleaved attention types (e.g. for Gemma3)
        self.window_kv_indptr = None
        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indptr_buf is None:
                self.window_kv_indptr = torch.zeros(
                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
                )
            else:
                # When provided a buffer, allocate a separate zero-filled
                # buffer with the same shape/dtype/device for the window path.
                self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

            self.mask_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
            )

        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps

        # Head counts are per attention tensor-parallel rank.
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )

        self.static_kv_splits = get_bool_env_var(
            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
        )
        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        # Populated by the init_forward_metadata* methods before each forward.
        self.forward_metadata: Optional[ForwardMetadata] = None

        self.max_context_len = model_runner.model_config.context_len

        self.device = model_runner.device
        self.device_core_count = get_device_core_count(model_runner.gpu_id)

    def get_num_kv_splits(
        self,
        num_kv_splits: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> None:
        """Fill ``num_kv_splits`` in place with the KV split count per token.

        Args:
            num_kv_splits: Output tensor; its length must be a whole multiple
                of ``seq_lens.shape[0]``.
            seq_lens: KV length of each sequence.

        Raises:
            AssertionError: If the number of tokens is not a multiple of the
                number of sequences.
        """
        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]

        if num_seq == 0:
            # Nothing to schedule for an empty batch; this also avoids the
            # integer division by zero below.
            num_kv_splits.fill_(self.max_kv_splits)
            return

        # NOTE(alcanderian): Considering speculative decoding,
        # num_kv_splits.shape[0] will be topk * real_num_token.
        # And the real_num_token is num_seq in decoding phase.
        num_group = num_token // num_seq

        assert (
            num_group * num_seq == num_token
        ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"

        if self.static_kv_splits or self.device_core_count <= 0:
            # Static mode (or unknown core count): always use the maximum.
            num_kv_splits.fill_(self.max_kv_splits)
            return

        # The kernel's block size is a compile-time constant: at least 256,
        # otherwise the next power of two covering all sequences.
        if num_seq < 256:
            SCHEDULE_SEQ = 256
        else:
            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)

        get_num_kv_splits_triton[(1,)](
            num_kv_splits,
            seq_lens,
            num_seq,
            num_group,
            self.num_head,
            self.num_kv_head,
            self.max_kv_splits,
            self.device_core_count,
            MAX_NUM_SEQ=SCHEDULE_SEQ,
        )
159

160
161
162
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend.

        Builds the index tensors for the batch's forward mode (decode/idle,
        target-verify, draft-extend, or regular extend) and stores them in
        ``self.forward_metadata``.
        """

        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        window_kv_indptr = self.window_kv_indptr
        window_kv_indices = None
        window_num_kv_splits = None
        spec_info = forward_batch.spec_info

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.empty(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                # Sliding window
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_kv_indptr, window_kv_indices, window_kv_lens = (
                        update_sliding_window_buffer(
                            self.window_kv_indptr,
                            self.req_to_token,
                            self.sliding_window_size,
                            forward_batch.seq_lens,
                            forward_batch.req_pool_indices,
                            bs,
                            self.device,
                        )
                    )
                    window_num_kv_splits = torch.empty(
                        (bs,), dtype=torch.int32, device=self.device
                    )
                    self.get_num_kv_splits(window_num_kv_splits, window_kv_lens)
            else:
                # Speculative decoding supplies its own indices; the effective
                # batch size is taken from the length of its kv_indptr.
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            attn_logits = torch.empty(
                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
                dtype=torch.float32,
                device=self.device,
            )
            attn_lse = torch.empty(
                (bs, self.num_head, self.max_kv_splits),
                dtype=torch.float32,
                device=self.device,
            )
            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)

            qo_indptr = None
            custom_mask = None
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
            # TODO: Support sliding window in spec inference
            bs = len(forward_batch.req_pool_indices)
            qo_indptr = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            # Different with flashinfer kv_indptr and kv_indices construction
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                kv_indptr[-1], dtype=torch.int32, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            custom_mask = spec_info.custom_mask
            # Per-request mask length: num_draft_tokens * (seq_len + num_draft_tokens).
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens
            )
            mask_indptr = self.mask_indptr
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
            mask_indptr = mask_indptr[: bs + 1]
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    self.req_to_token,
                )
            )
            mask_indptr = None
            # TODO(FIXME): This will trigger an invalid Eagle tree when using
            # `max(spec_info.accept_length_cpu)`.
            # It might have been forgotten to update somewhere.
            max_extend_len = torch.max(spec_info.accept_length).item()
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
            # Regular extend: kv_indptr/kv_indices are built from the prefix
            # lengths, qo_indptr from the extend lengths.
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            # Sliding window
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
                    self.window_kv_indptr,
                    self.req_to_token,
                    self.sliding_window_size,
                    forward_batch.extend_prefix_lens,
                    forward_batch.req_pool_indices,
                    bs,
                    self.device,
                )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_indptr = None
            attn_logits = None
            attn_lse = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
            num_kv_splits = None

        # Keyword arguments keep this call correct regardless of field order.
        self.forward_metadata = ForwardMetadata(
            attn_logits=attn_logits,
            attn_lse=attn_lse,
            max_extend_len=max_extend_len,
            num_kv_splits=num_kv_splits,
            kv_indptr=kv_indptr,
            kv_indices=kv_indices,
            qo_indptr=qo_indptr,
            custom_mask=custom_mask,
            mask_indptr=mask_indptr,
            window_kv_indptr=window_kv_indptr,
            window_kv_indices=window_kv_indices,
            window_num_kv_splits=window_num_kv_splits,
        )
337

338
    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        """Allocate the persistent buffers reused by CUDA graph capture/replay.

        Args:
            max_bs: Accepted for interface compatibility; not read here.
            max_num_tokens: Leading dimension used to size every buffer.
            kv_indices_buf: Optional pre-allocated buffer used as
                ``self.cuda_graph_kv_indices`` instead of allocating one.
        """
        self.cuda_graph_attn_logits = torch.zeros(
            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_attn_lse = torch.zeros(
            (max_num_tokens, self.num_head, self.max_kv_splits),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_num_kv_splits = torch.full(
            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_num_tokens * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            # Only the extend (target-verify) paths read the custom mask.
            self.cuda_graph_custom_mask = torch.zeros(
                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indices_buf is None:
                self.cuda_graph_window_kv_indices = torch.zeros(
                    (max_num_tokens * self.sliding_window_size),
                    dtype=torch.int32,
                    device=self.device,
                )
            else:
                # Separate zero-filled buffer matching the provided one.
                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

            self.cuda_graph_window_num_kv_splits = torch.full(
                (max_num_tokens,),
                self.max_kv_splits,
                dtype=torch.int32,
                device=self.device,
            )

390
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Build ``self.forward_metadata`` for CUDA graph capture.

        Index tensors are written into the persistent buffers created by
        ``__init__`` and ``init_cuda_graph_state`` rather than freshly
        allocated.

        Raises:
            AssertionError: If ``encoder_lens`` is not None.
            ValueError: If ``forward_mode`` is not decode/idle, target-verify
                or draft-extend.
        """
        assert encoder_lens is None, "Not supported"
        window_kv_indptr = self.window_kv_indptr
        window_kv_indices = None
        window_num_kv_splits = None

        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_kv_indices = self.cuda_graph_window_kv_indices
                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                    window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices,
                        bs,
                    )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            attn_logits = self.cuda_graph_attn_logits
            attn_lse = self.cuda_graph_attn_lse
            max_extend_len = None
            num_kv_splits = self.cuda_graph_num_kv_splits
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
        elif forward_mode.is_target_verify():
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        elif forward_mode.is_draft_extend():
            # At capture time every request is given the same query length.
            num_tokens_per_bs = self.speculative_num_steps + 1
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                bs * num_tokens_per_bs + 1,
                step=num_tokens_per_bs,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            custom_mask = None
            mask_indptr = None
            max_extend_len = num_tokens_per_bs
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
            )

        # Keyword arguments keep this call correct regardless of field order.
        self.forward_metadata = ForwardMetadata(
            attn_logits=attn_logits,
            attn_lse=attn_lse,
            max_extend_len=max_extend_len,
            num_kv_splits=num_kv_splits,
            kv_indptr=kv_indptr,
            kv_indices=kv_indices,
            qo_indptr=qo_indptr,
            custom_mask=custom_mask,
            mask_indptr=mask_indptr,
            window_kv_indptr=window_kv_indptr,
            window_kv_indices=window_kv_indices,
            window_num_kv_splits=window_num_kv_splits,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Refresh the persistent CUDA graph buffers in place before a replay.

        No new ``ForwardMetadata`` is created; every write goes into buffers
        (or slices of them) owned by this backend. ``seq_lens_sum``,
        ``encoder_lens`` and ``seq_lens_cpu`` are accepted but not read here.

        Raises:
            ValueError: If ``forward_mode`` is not decode/idle, target-verify
                or draft-extend.
        """
        # NOTE: encoder_lens expected to be zeros or None
        if forward_mode.is_decode_or_idle():
            # Update kv_indptr, kv_indices
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            num_kv_splits = self.cuda_graph_num_kv_splits
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                num_token = bs
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                    window_kv_indices = self.cuda_graph_window_kv_indices
                    _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices[:bs],
                        bs,
                    )
                    self.get_num_kv_splits(
                        window_num_kv_splits[:num_token], window_kv_lens[:bs]
                    )

            else:
                # Copy the speculative indices into the persistent buffers.
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
                num_token = spec_info.kv_indptr.shape[0] - 1
            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])

        elif forward_mode.is_target_verify():
            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
            bs = len(req_pool_indices)
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
        elif forward_mode.is_draft_extend():
            seq_lens = seq_lens[:bs]
            accept_lens = spec_info.accept_length[:bs]
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
            )
628

629
630
631
    def get_cuda_graph_seq_len_fill_value(self):
        """Return the sequence-length fill value for CUDA graphs (always 1)."""
        fill_value = 1
        return fill_value

632
    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run the extend attention kernel for one layer.

        Optionally writes ``k``/``v`` into the KV pool at
        ``forward_batch.out_cache_loc`` first, then calls the kernel with the
        indices held in ``self.forward_metadata``.

        Returns:
            A new output tensor. It has the same shape as ``q`` unless the
            layer's QK and V head dims differ, in which case it is
            ``(q.shape[0], tp_q_head_num * v_head_dim)``.
        """
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        # Encoder-only layers attend bidirectionally.
        causal = True
        if layer.attn_type == AttentionType.ENCODER_ONLY:
            causal = False

        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            sliding_window_size = (
                layer.sliding_window_size
            )  # Needed for sliding window mask
            kv_indptr = self.forward_metadata.window_kv_indptr
            kv_indices = self.forward_metadata.window_kv_indices
        else:
            sliding_window_size = -1
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            self.forward_metadata.qo_indptr,
            kv_indptr,
            kv_indices,
            self.forward_metadata.custom_mask,
            causal,
            self.forward_metadata.mask_indptr,
            self.forward_metadata.max_extend_len,
            layer.scaling,
            layer.logit_cap,
            sliding_window_size,
        )
        return o

687
    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run the decode attention kernel for one layer.

        Optionally writes ``k``/``v`` into the KV pool at
        ``forward_batch.out_cache_loc`` first, then calls the kernel with the
        indices and split buffers held in ``self.forward_metadata``.

        Returns:
            A new 2-D output tensor with one row per query token and
            ``tp_q_head_num * v_head_dim`` columns.
        """
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        # Sliding-window layers read the window-restricted index buffers.
        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            kv_indptr = self.forward_metadata.window_kv_indptr
            kv_indices = self.forward_metadata.window_kv_indices
        else:
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            kv_indptr,
            kv_indices,
            self.forward_metadata.attn_logits,
            self.forward_metadata.attn_lse,
            self.forward_metadata.num_kv_splits,
            self.max_kv_splits,
            layer.scaling,
            layer.logit_cap,
        )
        return o
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751


class TritonMultiStepDraftBackend:
    """
    Wrap multiple triton attention backends as one for multiple consecutive
    draft decoding steps.

    One ``TritonAttnBackend`` is created per draft step. Step ``i`` is handed
    row ``i`` of the shared ``kv_indptr`` buffer, and the per-step KV indices
    are produced for all steps at once by ``generate_draft_decode_kv_indices``.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        """Allocate the shared ``kv_indptr`` buffer and one backend per step.

        Args:
            model_runner: Source of the device, the request-to-token pool and
                the model/server configuration read below.
            topk: Multiplier applied to the batch size for draft decoding.
            speculative_num_steps: Number of consecutive draft steps, i.e. the
                number of wrapped backends.
        """
        # Lazy import, kept local to this constructor as in the rest of the file.
        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices

        max_bs = model_runner.req_to_token_pool.size * self.topk
        # One row of ``max_bs + 1`` offsets per draft step.
        self.kv_indptr = torch.zeros(
            (self.speculative_num_steps, max_bs + 1),
            dtype=torch.int32,
            device=model_runner.device,
        )
        # Each backend writes into its own row of the shared buffer.
        self.attn_backends = [
            TritonAttnBackend(
                model_runner,
                skip_prefill=True,
                kv_indptr_buf=self.kv_indptr[step],
            )
            for step in range(self.speculative_num_steps)
        ]
        self.max_context_len = self.attn_backends[0].max_context_len
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.device = model_runner.device

        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
        self.page_size = model_runner.server_args.page_size

    def common_template(
        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn
    ):
        """Generate KV indices for every draft step, then visit each step.

        Args:
            forward_batch: Batch whose ``spec_info.kv_indptr`` and
                ``spec_info.kv_indices`` are overwritten before each callback.
            kv_indices_buffer: 2-D buffer indexed by step on its first axis;
                filled by ``generate_draft_decode_kv_indices``.
            call_fn: Callable invoked as ``call_fn(i, forward_batch)`` once per
                step ``i`` in ``range(self.speculative_num_steps)``.
        """
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        # Launch grid: (draft step, sequence, top-k branch).
        launch_grid = (self.speculative_num_steps, num_seqs, self.topk)
        self.generate_draft_decode_kv_indices[launch_grid](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
            next_power_of_2(bs),
            self.page_size,
        )

        for i in range(self.speculative_num_steps):
            # Expose only the part of each shared buffer that step ``i`` uses.
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize forward metadata on every wrapped backend (eager path)."""
        kv_indices = torch.empty(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        def call_fn(i, forward_batch):
            # Clone the slices so each backend receives its own copy rather
            # than a view of the buffers shared across steps.
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Allocate the CUDA-graph KV index buffer and share it per step.

        Row ``i`` of ``self.cuda_graph_kv_indices`` is passed to backend ``i``
        as its ``kv_indices_buf``.
        """
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
            dtype=torch.int32,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        """Initialize per-step metadata while capturing a CUDA graph."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        """Initialize per-step metadata before replaying a CUDA graph.

        Args:
            forward_batch: Batch to replay.
            bs: Batch size forwarded unchanged to every wrapped backend.
        """

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=None,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983


@triton.jit
def get_num_kv_splits_triton(
    num_kv_splits_ptr,
    seq_lens_ptr,
    num_seq,
    num_group,
    num_head,
    num_kv_head,
    max_kv_splits,
    device_core_count,
    MAX_NUM_SEQ: tl.constexpr,
):
    # Heuristically choose, for every sequence, into how many chunks ("KV
    # splits") its length is divided, and write the result to
    # num_kv_splits_ptr.
    #
    # Arguments, as used below:
    #   num_kv_splits_ptr: output; slots [0, num_seq * num_group) are written.
    #   seq_lens_ptr:      input; the first num_seq entries are read.
    #   num_seq:           number of valid sequences.
    #   num_group:         number of consecutive output slots that receive
    #                      each sequence's value.
    #   num_head, num_kv_head: head counts; their integer ratio selects how
    #                      the launch-grid estimate is computed.
    #   max_kv_splits:     cap applied to both split estimates.
    #   device_core_count: core count that scales the second estimate.
    #   MAX_NUM_SEQ:       compile-time width of the sequence vector; lanes
    #                      at or beyond num_seq are masked.
    # TODO: this method is tunable, we need more online serving data to tune it
    offs_seq = tl.arange(0, MAX_NUM_SEQ)
    mask_seq = offs_seq < num_seq

    # Masked lanes load 0 so they cannot raise the maximum ...
    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
    max_seq_len = tl.max(seq_lens)
    # ... and are reloaded as max_seq_len so they cannot lower the minimum.
    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
    min_seq_len = tl.min(seq_lens)
    # When the shortest sequence is longer than 80% of the longest, treat all
    # lengths as equal; the ratio-based estimate below then becomes 1.
    if max_seq_len * 8 < min_seq_len * 10:
        min_seq_len = max_seq_len
    # Estimate 1: ratio of longest to shortest length, capped by max_kv_splits.
    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)

    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
    # The core count is scaled by max(log2(max_seq_len / 64), 1).
    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
    ext_device_core_count = tl.cast(
        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
    )
    block_h, num_kv_group = 16, num_head // num_kv_head
    if num_kv_group == 1:
        token_grid = num_seq * num_group * num_head
    else:
        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
        block_h = tl.minimum(block_h, num_kv_group)
        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
    # Estimate 2: scaled core count divided by the grid size, capped by
    # max_kv_splits.
    max_kv_splits_2 = tl.minimum(
        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
    )
    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)

    # Per sequence, keep the larger split count implied by the two chunk sizes.
    num_kv_splits = tl.maximum(
        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
    )

    # Sequence s owns output slots [s * num_group, (s + 1) * num_group); the
    # loop writes the same value into each of them.
    offs_token = offs_seq * num_group
    mask_token = offs_token < num_seq * num_group
    for i in range(0, num_group):
        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)


def update_sliding_window_buffer(
    window_kv_indptr,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    device,
):
    """Build sliding-window KV offsets and indices for a batch.

    The per-request window length is ``min(seq_len, sliding_window_size + 1)``.
    Its running sum is written in place into ``window_kv_indptr[1 : bs + 1]``,
    a fresh ``int32`` index tensor of the total length is allocated on
    ``device``, and ``create_flashinfer_kv_indices_triton`` is launched with
    one program per request to fill it.

    Returns:
        A tuple ``(window_kv_indptr[: bs + 1], window_kv_indices,
        window_kv_lens)``.
    """
    # Upper bound on the number of KV entries any single request may use.
    window_cap = torch.tensor(sliding_window_size + 1)
    kv_lens = torch.minimum(seq_lens, window_cap)

    # Write the cumulative lengths in place, then keep only the used prefix.
    window_kv_indptr[1 : bs + 1] = kv_lens.cumsum(dim=0)
    indptr = window_kv_indptr[: bs + 1]

    # The last offset is the total number of windowed KV entries.
    kv_indices = torch.empty(indptr[-1], dtype=torch.int32, device=device)

    # Offset of the first in-window entry of every request.
    start_idx = seq_lens - kv_lens

    launch_grid = (bs,)
    create_flashinfer_kv_indices_triton[launch_grid](
        req_to_token,
        req_pool_indices,
        kv_lens,
        indptr,
        start_idx,
        kv_indices,
        req_to_token.stride(0),
    )
    return indptr, kv_indices, kv_lens


def update_sliding_window_buffer_cuda_graph(
    window_kv_indptr,
    window_kv_indices,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
):
    """Refresh sliding-window KV offsets and indices in caller-owned buffers.

    The per-request window length is ``min(seq_len, sliding_window_size + 1)``.
    Its running sum is written in place into ``window_kv_indptr[1 : bs + 1]``,
    and ``create_flashinfer_kv_indices_triton`` is launched with one program
    per request, receiving the caller-provided ``window_kv_indices`` as its
    index buffer. No index tensor is allocated here.

    Returns:
        A tuple ``(window_kv_indptr[: bs + 1], window_kv_lens)``.
    """
    # Upper bound on the number of KV entries any single request may use.
    window_cap = torch.tensor(sliding_window_size + 1)
    kv_lens = torch.minimum(seq_lens, window_cap)

    # Write the cumulative lengths in place, then keep only the used prefix.
    window_kv_indptr[1 : bs + 1] = kv_lens.cumsum(dim=0)
    indptr = window_kv_indptr[: bs + 1]

    # Offset of the first in-window entry of every request.
    start_idx = seq_lens - kv_lens

    launch_grid = (bs,)
    create_flashinfer_kv_indices_triton[launch_grid](
        req_to_token,
        req_pool_indices,
        kv_lens,
        indptr,
        start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )
    return indptr, kv_lens