from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


@dataclass
class ForwardMetadata:
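    """Auxiliary tensors prepared once per forward batch for the Triton attention kernels."""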
    attn_logits: torch.Tensor
    attn_lse: torch.Tensor
    max_extend_len: int
    num_kv_splits: torch.Tensor
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: torch.Tensor
    custom_mask: torch.Tensor
    mask_indptr: torch.Tensor
    # Sliding window
    window_kv_indptr: torch.Tensor
    window_kv_indices: torch.Tensor
    window_num_kv_splits: torch.Tensor
    window_kv_offsets: torch.Tensor


class TritonAttnBackend(AttentionBackend):
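    """Attention backend based on the Triton extend- and decode-attention kernels."""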
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        # Lazy import to avoid initializing the CUDA context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

        # Parse args
        self.skip_prefill = skip_prefill
        max_bs = model_runner.req_to_token_pool.size
        self.sliding_window_size = model_runner.sliding_window_size
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.device_core_count = get_device_core_count(model_runner.gpu_id)
        self.static_kv_splits = get_bool_env_var(
            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
        )
        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits

        # Check arguments
        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        # Initialize buffers
        # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        # If sliding window is enabled, we might need two sets of buffers
        # because of interleaved attention types (e.g. for Gemma3)
        self.window_kv_indptr = None
        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indptr_buf is None:
                self.window_kv_indptr = torch.zeros(
                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
                )
            else:
                # When provided a buffer, create a clone for the second buffer
                self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

            self.mask_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
            )

        # Initialize forward metadata
        self.forward_metadata: Optional[ForwardMetadata] = None

    def get_num_kv_splits(
        self,
        num_kv_splits: torch.Tensor,
        seq_lens: torch.Tensor,
    ):
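        """Fill ``num_kv_splits`` with the number of KV splits to use per token.

        Falls back to the static maximum when requested via the env var or when
        the device core count is unknown; otherwise the heuristic Triton kernel
        below picks a per-sequence value.
        """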
        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
        # NOTE(alcanderian): Considering speculative decoding,
        # num_kv_splits.shape[0] will be topk * real_num_token,
        # and real_num_token equals num_seq in the decode phase.
        num_group = num_token // num_seq
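        # e.g. (illustrative numbers): 8 decode sequences with speculative topk=4
        # give num_token == 32 and num_group == 4.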

        assert (
            num_group * num_seq == num_token
        ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"

        if self.static_kv_splits or self.device_core_count <= 0:
            num_kv_splits.fill_(self.max_kv_splits)
            return

        if num_seq < 256:
            SCHEDULE_SEQ = 256
        else:
            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)

        get_num_kv_splits_triton[(1,)](
            num_kv_splits,
            seq_lens,
            num_seq,
            num_group,
            self.num_head,
            self.num_kv_head,
            self.max_kv_splits,
            self.device_core_count,
            MAX_NUM_SEQ=SCHEDULE_SEQ,
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend."""

        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        window_kv_indptr = self.window_kv_indptr
        window_kv_indices = None
        window_num_kv_splits = None
        window_kv_offsets = None
        spec_info = forward_batch.spec_info

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.empty(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                # Sliding window
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
                        update_sliding_window_buffer(
                            self.window_kv_indptr,
                            self.req_to_token,
                            self.sliding_window_size,
                            forward_batch.seq_lens,
                            forward_batch.req_pool_indices,
                            bs,
                            self.device,
                            self.token_to_kv_pool_allocator,
                        )
                    )
                    window_num_kv_splits = torch.empty(
                        (bs,), dtype=torch.int32, device=self.device
                    )
                    self.get_num_kv_splits(window_num_kv_splits, window_kv_lens)
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            attn_logits = torch.empty(
                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
                dtype=torch.float32,
                device=self.device,
            )
            attn_lse = torch.empty(
                (bs, self.num_head, self.max_kv_splits),
                dtype=torch.float32,
                device=self.device,
            )
            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)

            qo_indptr = None
            custom_mask = None
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
            bs = len(forward_batch.req_pool_indices)
            qo_indptr = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            # Different from the flashinfer kv_indptr and kv_indices construction
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                kv_indptr[-1], dtype=torch.int32, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                # window_kv_offsets is used to calculate the start position in the custom mask
                (
                    window_kv_indptr,
                    window_kv_indices,
                    window_kv_lens,
                    window_kv_offsets,
                ) = update_sliding_window_buffer(
                    self.window_kv_indptr,
                    self.req_to_token,
                    self.sliding_window_size,
                    forward_batch.seq_lens,
                    forward_batch.req_pool_indices,
                    bs,
                    self.device,
                    self.token_to_kv_pool_allocator,
                )

            custom_mask = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens
            )
            mask_indptr = self.mask_indptr
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
            mask_indptr = mask_indptr[: bs + 1]
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None

        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    self.req_to_token,
                )
            )
            mask_indptr = None
            # TODO(FIXME): Using `max(spec_info.accept_length_cpu)` here triggers an
            # invalid Eagle tree; the CPU copy has probably not been updated somewhere.
            max_extend_len = torch.max(spec_info.accept_length).item()
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            # Sliding window
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_kv_indptr, window_kv_indices, _, _ = (
                    update_sliding_window_buffer(
                        self.window_kv_indptr,
                        self.req_to_token,
                        self.sliding_window_size,
                        forward_batch.extend_prefix_lens,
                        forward_batch.req_pool_indices,
                        bs,
                        self.device,
                        self.token_to_kv_pool_allocator,
                    )
                )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_indptr = None
            attn_logits = None
            attn_lse = None
            max_extend_len = max(forward_batch.extend_seq_lens_cpu)
            num_kv_splits = None

        self.forward_metadata = ForwardMetadata(
            attn_logits,
            attn_lse,
            max_extend_len,
            num_kv_splits,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
            window_kv_indptr,
            window_kv_indices,
            window_num_kv_splits,
            window_kv_offsets,
        )

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
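        """Pre-allocate the buffers reused during CUDA graph capture and replay."""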
        self.cuda_graph_attn_logits = torch.zeros(
            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_attn_lse = torch.zeros(
            (max_num_tokens, self.num_head, self.max_kv_splits),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_num_kv_splits = torch.full(
            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_num_tokens * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indices_buf is None:
                self.cuda_graph_window_kv_indices = torch.zeros(
                    (max_num_tokens * self.sliding_window_size),
                    dtype=torch.int32,
                    device=self.device,
                )
            else:
                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

            self.cuda_graph_window_num_kv_splits = torch.full(
                (max_num_tokens,),
                self.max_kv_splits,
                dtype=torch.int32,
                device=self.device,
            )

            self.cuda_graph_window_kv_offsets = torch.zeros(
                (max_bs,),
                dtype=torch.int32,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
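        """Build forward metadata into the pre-allocated buffers at CUDA graph capture time."""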
        assert encoder_lens is None, "Not supported"
        window_kv_indptr = self.window_kv_indptr
        window_kv_indices = None
        window_num_kv_splits = None
        window_kv_offsets = None

        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_kv_indices = self.cuda_graph_window_kv_indices
                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                    window_kv_indptr, window_kv_indices, _, _ = (
                        update_sliding_window_buffer_cuda_graph(
                            self.window_kv_indptr,
                            window_kv_indices,
                            self.req_to_token,
                            self.sliding_window_size,
                            seq_lens[:bs],
                            req_pool_indices,
                            bs,
                            self.token_to_kv_pool_allocator,
                        )
                    )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            attn_logits = self.cuda_graph_attn_logits
            attn_lse = self.cuda_graph_attn_lse
            max_extend_len = None
            num_kv_splits = self.cuda_graph_num_kv_splits
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
        elif forward_mode.is_target_verify():
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_kv_indices = self.cuda_graph_window_kv_indices
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                window_kv_offsets = self.cuda_graph_window_kv_offsets
                window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
                    update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices,
                        bs,
                        self.token_to_kv_pool_allocator,
                    )
                )

            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        elif forward_mode.is_draft_extend():
            num_tokens_per_bs = self.speculative_num_steps + 1
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                bs * num_tokens_per_bs + 1,
                step=num_tokens_per_bs,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            custom_mask = None
            mask_indptr = None
            max_extend_len = num_tokens_per_bs
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
            )

        self.forward_metadata = ForwardMetadata(
            attn_logits,
            attn_lse,
            max_extend_len,
            num_kv_splits,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
            window_kv_indptr,
            window_kv_indices,
            window_num_kv_splits,
            window_kv_offsets,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
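        """Refresh the captured metadata buffers in place before a CUDA graph replay."""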
        # NOTE: encoder_lens is expected to be all zeros or None
        if forward_mode.is_decode_or_idle():
            # Update kv_indptr, kv_indices
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            num_kv_splits = self.cuda_graph_num_kv_splits
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                num_token = bs
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                    window_kv_indices = self.cuda_graph_window_kv_indices
                    _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices[:bs],
                        bs,
                        self.token_to_kv_pool_allocator,
                    )
                    self.get_num_kv_splits(
                        window_num_kv_splits[:num_token], window_kv_lens[:bs]
                    )

            else:
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
                num_token = spec_info.kv_indptr.shape[0] - 1
            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])

        elif forward_mode.is_target_verify():
            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
            bs = len(req_pool_indices)
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                window_kv_indices = self.cuda_graph_window_kv_indices
                window_kv_offsets = self.cuda_graph_window_kv_offsets
                _, _, window_kv_lens, window_kv_offsets[:bs] = (
                    update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices,
                        bs,
                        self.token_to_kv_pool_allocator,
                    )
                )
            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
        elif forward_mode.is_draft_extend():
            seq_lens = seq_lens[:bs]
            accept_lens = spec_info.accept_length[:bs]
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
            )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        sinks=None,
    ):
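        """Run the Triton extend (prefill) attention kernel and return the output tensor."""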
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        causal = True
        if layer.attn_type == AttentionType.ENCODER_ONLY:
            causal = False

        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            sliding_window_size = (
                layer.sliding_window_size
            )  # Needed for sliding window mask
            kv_indptr = self.forward_metadata.window_kv_indptr
            kv_indices = self.forward_metadata.window_kv_indices
            window_kv_offsets = self.forward_metadata.window_kv_offsets
        else:
            sliding_window_size = -1
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices
            window_kv_offsets = None

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            self.forward_metadata.qo_indptr,
            kv_indptr,
            kv_indices,
            self.forward_metadata.custom_mask,
            causal,
            self.forward_metadata.mask_indptr,
            self.forward_metadata.max_extend_len,
            layer.scaling,
            layer.logit_cap,
            sliding_window_size=sliding_window_size,
            sinks=sinks,
            window_kv_offsets=window_kv_offsets,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        sinks=None,
    ):
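        """Run the Triton split-KV decode attention kernel and return the output tensor."""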
        # Under torch.compile, a bug in rotary_emb can leave the output with a 3D
        # shape; reshape it back to 2D here.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            kv_indptr = self.forward_metadata.window_kv_indptr
            kv_indices = self.forward_metadata.window_kv_indices
        else:
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            kv_indptr,
            kv_indices,
            self.forward_metadata.attn_logits,
            self.forward_metadata.attn_lse,
            self.forward_metadata.num_kv_splits,
            self.max_kv_splits,
            layer.scaling,
            layer.logit_cap,
            sinks=sinks,
        )
        return o


class TritonMultiStepDraftBackend:
    """
    Wrap multiple triton attention backends as one for multiple consecutive
    draft decoding steps.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
        max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.attn_backends = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                TritonAttnBackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                )
            )
        self.max_context_len = self.attn_backends[0].max_context_len
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.device = model_runner.device
        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
        self.page_size = model_runner.server_args.page_size

    def common_template(
        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: Callable
    ):
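        """Generate draft-decode kv indices for all speculative steps, then invoke
        ``call_fn(step, forward_batch)`` with each step's kv_indptr/kv_indices set."""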
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
            next_power_of_2(bs),
            self.page_size,
        )

        for i in range(self.speculative_num_steps):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        kv_indices = torch.empty(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        def call_fn(i, forward_batch):
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
            dtype=torch.int32,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=None,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


@triton.jit
def get_num_kv_splits_triton(
    num_kv_splits_ptr,
    seq_lens_ptr,
    num_seq,
    num_group,
    num_head,
    num_kv_head,
    max_kv_splits,
    device_core_count,
    MAX_NUM_SEQ: tl.constexpr,
):
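    # Heuristic: choose a per-sequence number of KV splits from the sequence lengths,
    # head counts, and device core count, then store it for every token in the group.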
    # TODO: this method is tunable, we need more online serving data to tune it
    offs_seq = tl.arange(0, MAX_NUM_SEQ)
    mask_seq = offs_seq < num_seq

    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
    max_seq_len = tl.max(seq_lens)
    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
    min_seq_len = tl.min(seq_lens)
    if max_seq_len * 8 < min_seq_len * 10:
        min_seq_len = max_seq_len
    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)

    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
    ext_device_core_count = tl.cast(
        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
    )
    block_h, num_kv_group = 16, num_head // num_kv_head
    if num_kv_group == 1:
        token_grid = num_seq * num_group * num_head
    else:
        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
        block_h = tl.minimum(block_h, num_kv_group)
        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
    max_kv_splits_2 = tl.minimum(
        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
    )
    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)

    num_kv_splits = tl.maximum(
        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
    )

    offs_token = offs_seq * num_group
    mask_token = offs_token < num_seq * num_group
    for i in range(0, num_group):
        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)


def update_sliding_window_buffer(
    window_kv_indptr,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    device,
    token_to_kv_pool_allocator=None,
):
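    """Build kv_indptr/kv_indices covering only the last ``sliding_window_size`` tokens
    of each request; also returns the window lengths and their start offsets."""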
    window_kv_lens = torch.minimum(
        seq_lens,
        torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]
    window_kv_indices = torch.empty(
        window_kv_indptr[-1], dtype=torch.int32, device=device
    )
    window_kv_start_idx = seq_lens - window_kv_lens
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )
    # Map KV cache locations from the full attention pool to the SWA (sliding-window) pool
    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
        kv_last_index = window_kv_indptr[-1]
        window_kv_indices[:kv_last_index] = (
            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                window_kv_indices[:kv_last_index]
            )
        )
    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx


def update_sliding_window_buffer_cuda_graph(
    window_kv_indptr,
    window_kv_indices,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    token_to_kv_pool_allocator=None,
):
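    """Same as ``update_sliding_window_buffer``, but writes into a pre-allocated
    ``window_kv_indices`` buffer so it is safe to use under CUDA graphs."""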
    window_kv_lens = torch.minimum(
        seq_lens,
        torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]
    window_kv_start_idx = seq_lens - window_kv_lens
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )
    # Map KV cache locations from the full attention pool to the SWA (sliding-window) pool
    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
        kv_last_index = window_kv_indptr[-1]
        window_kv_indices[:kv_last_index] = (
            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                window_kv_indices[:kv_last_index]
            )
        )
    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx