from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import (
    get_bool_env_var,
    get_device_core_count,
    get_int_env_var,
    next_power_of_2,
)

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInput


def logit_capping_mod(logit_capping_method, logit_cap):
    # A positive logit_cap means a tanh cap is applied.
    if logit_capping_method == "tanh":
        return logit_cap
    else:
        raise ValueError(
            f"Unsupported logit capping method: {logit_capping_method}"
        )


@dataclass
class ForwardMetadata:
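    """Per-batch auxiliary tensors prepared for the Triton attention kernels."""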
    attn_logits: torch.Tensor
    attn_lse: torch.Tensor
    max_extend_len: int
    num_kv_splits: torch.Tensor
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: torch.Tensor
    custom_mask: torch.Tensor
    mask_indptr: torch.Tensor
    # Sliding window
    window_kv_indptr: torch.Tensor
    window_kv_indices: torch.Tensor
    window_num_kv_splits: torch.Tensor
    window_kv_offsets: torch.Tensor


class TritonAttnBackend(AttentionBackend):
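    """Attention backend that runs prefill (extend) and decode attention with Triton
    kernels, with optional sliding-window attention and CUDA graph support."""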
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        # Lazy import to avoid initializing the CUDA context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

        # Parse args
        self.skip_prefill = skip_prefill
        max_bs = model_runner.req_to_token_pool.size
        self.sliding_window_size = model_runner.sliding_window_size
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        if model_runner.hybrid_gdn_config is not None:
            # For hybrid linear models, layer_id = 0 may not be full attention
            self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
        else:
            self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
                -1
            ]
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.device_core_count = get_device_core_count(model_runner.gpu_id)
        self.static_kv_splits = get_bool_env_var(
            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
        )
        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits

        # Decide whether to enable deterministic inference with batch-invariant operations
        self.enable_deterministic = (
            model_runner.server_args.enable_deterministic_inference
        )

        # Configure deterministic inference settings
        if self.enable_deterministic:
            # Use fixed split tile size for batch invariance
            self.split_tile_size = get_int_env_var(
                "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
            )
            # Set static_kv_splits to False to use deterministic logic instead
            self.static_kv_splits = False
        else:
            self.split_tile_size = (
                model_runner.server_args.triton_attention_split_tile_size
            )

        if self.split_tile_size is not None:
            self.max_kv_splits = (
                self.max_context_len + self.split_tile_size - 1
            ) // self.split_tile_size

        # Check arguments
        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        # Initialize buffers
        # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        # If sliding window is enabled, we might need two sets of buffers
        # because of interleaved attention types (e.g. for Gemma3)
        self.window_kv_indptr = None
        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indptr_buf is None:
                self.window_kv_indptr = torch.zeros(
                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
                )
            else:
                # When provided a buffer, create a clone for the second buffer
                self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

            self.mask_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
            )

        # Initialize forward metadata
        self.forward_metadata: Optional[ForwardMetadata] = None

        self.cuda_graph_custom_mask = None

    def get_num_kv_splits(
        self,
        num_kv_splits: torch.Tensor,
        seq_lens: torch.Tensor,
    ):
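        """Fill `num_kv_splits` in place with the number of KV splits per token, based on
        the sequence lengths and the configured (static, deterministic, or dynamic) policy."""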
        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
        # NOTE(alcanderian): With speculative decoding,
        # num_kv_splits.shape[0] will be topk * real_num_token,
        # and real_num_token equals num_seq in the decoding phase.
        num_group = num_token // num_seq

        assert (
            num_group * num_seq == num_token
        ), f"num_seq({num_seq}), num_token({num_token}), something went wrong!"

        # Legacy dynamic splitting logic (non-deterministic)
        if (
            self.static_kv_splits or self.device_core_count <= 0
        ) and not self.enable_deterministic:
            num_kv_splits.fill_(self.max_kv_splits)
            return

        # deterministic
        if self.split_tile_size is not None and self.enable_deterministic:
            # expand seq_lens to match num_token
            if num_group > 1:
                expanded_seq_lens = seq_lens.repeat_interleave(num_group)
            else:
                expanded_seq_lens = seq_lens

            num_kv_splits[:] = (
                expanded_seq_lens + self.split_tile_size - 1
            ) // self.split_tile_size
            return

        if num_seq < 256:
            SCHEDULE_SEQ = 256
        else:
            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)

        get_num_kv_splits_triton[(1,)](
            num_kv_splits,
            seq_lens,
            num_seq,
            num_group,
            self.num_head,
            self.num_kv_head,
            self.max_kv_splits,
            self.device_core_count,
            MAX_NUM_SEQ=SCHEDULE_SEQ,
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend."""

        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        window_kv_indptr = self.window_kv_indptr
        window_kv_indices = None
        window_num_kv_splits = None
        window_kv_offsets = None
        spec_info = forward_batch.spec_info

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.empty(
                    forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                # Sliding window
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
                        update_sliding_window_buffer(
                            self.window_kv_indptr,
                            self.req_to_token,
                            self.sliding_window_size,
                            forward_batch.seq_lens,
                            forward_batch.req_pool_indices,
                            bs,
                            self.device,
                            self.token_to_kv_pool_allocator,
                        )
                    )
                    window_num_kv_splits = torch.empty(
                        (bs,), dtype=torch.int32, device=self.device
                    )
                    self.get_num_kv_splits(window_num_kv_splits, window_kv_lens)
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            attn_logits = torch.empty(
                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
                dtype=torch.float32,
                device=self.device,
            )
            attn_lse = torch.empty(
                (bs, self.num_head, self.max_kv_splits),
                dtype=torch.float32,
                device=self.device,
            )
            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)

            qo_indptr = None
            custom_mask = None
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
            bs = len(forward_batch.req_pool_indices)
            qo_indptr = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            # Different from the flashinfer kv_indptr and kv_indices construction
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                kv_indptr[-1], dtype=torch.int64, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                # window_kv_offsets is used to calculate the start position in custom mask
                (
                    window_kv_indptr,
                    window_kv_indices,
                    window_kv_lens,
                    window_kv_offsets,
                ) = update_sliding_window_buffer(
                    self.window_kv_indptr,
                    self.req_to_token,
                    self.sliding_window_size,
                    forward_batch.seq_lens,
                    forward_batch.req_pool_indices,
                    bs,
                    self.device,
                    self.token_to_kv_pool_allocator,
                )

            custom_mask = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens
            )
            mask_indptr = self.mask_indptr
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
            mask_indptr = mask_indptr[: bs + 1]
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None

        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    self.req_to_token,
                )
            )
            kv_indices = kv_indices.to(torch.int64)
            mask_indptr = None
            # TODO(FIXME): Using `max(spec_info.accept_length_cpu)` here triggers an
            # invalid Eagle tree; the value is probably not being updated somewhere.
            max_extend_len = torch.max(spec_info.accept_length).item()
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int64,
                device=self.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            # Sliding window
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_kv_indptr, window_kv_indices, _, _ = (
                    update_sliding_window_buffer(
                        self.window_kv_indptr,
                        self.req_to_token,
                        self.sliding_window_size,
                        forward_batch.extend_prefix_lens,
                        forward_batch.req_pool_indices,
                        bs,
                        self.device,
                        self.token_to_kv_pool_allocator,
                    )
                )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_indptr = None
            attn_logits = None
            attn_lse = None
            max_extend_len = max(forward_batch.extend_seq_lens_cpu)
            num_kv_splits = None

        self.forward_metadata = ForwardMetadata(
            attn_logits,
            attn_lse,
            max_extend_len,
            num_kv_splits,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
            window_kv_indptr,
            window_kv_indices,
            window_num_kv_splits,
            window_kv_offsets,
        )

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
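        """Pre-allocate buffers that are reused across CUDA graph capture and replay."""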
        self.cuda_graph_attn_logits = torch.zeros(
            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_attn_lse = torch.zeros(
            (max_num_tokens, self.num_head, self.max_kv_splits),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_num_kv_splits = torch.full(
            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_num_tokens * self.max_context_len),
                dtype=torch.int64,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indices_buf is None:
                self.cuda_graph_window_kv_indices = torch.zeros(
                    (max_num_tokens * self.sliding_window_size),
                    dtype=torch.int64,
                    device=self.device,
                )
            else:
                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

            self.cuda_graph_window_num_kv_splits = torch.full(
                (max_num_tokens,),
                self.max_kv_splits,
                dtype=torch.int32,
                device=self.device,
            )

            self.cuda_graph_window_kv_offsets = torch.zeros(
                (max_bs,),
                dtype=torch.int32,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInput],
    ):
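        """Build forward metadata on the pre-allocated buffers for CUDA graph capture."""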
        assert encoder_lens is None, "Not supported"
        window_kv_indptr = self.window_kv_indptr
        window_kv_indices = None
        window_num_kv_splits = None
        window_kv_offsets = None

        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_kv_indices = self.cuda_graph_window_kv_indices
                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                    window_kv_indptr, window_kv_indices, _, _ = (
                        update_sliding_window_buffer_cuda_graph(
                            self.window_kv_indptr,
                            window_kv_indices,
                            self.req_to_token,
                            self.sliding_window_size,
                            seq_lens[:bs],
                            req_pool_indices,
                            bs,
                            self.token_to_kv_pool_allocator,
                        )
                    )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            attn_logits = self.cuda_graph_attn_logits
            attn_lse = self.cuda_graph_attn_lse
            max_extend_len = None
            num_kv_splits = self.cuda_graph_num_kv_splits
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
        elif forward_mode.is_target_verify():
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_kv_indices = self.cuda_graph_window_kv_indices
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                window_kv_offsets = self.cuda_graph_window_kv_offsets
                window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
                    update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices,
                        bs,
                        self.token_to_kv_pool_allocator,
                    )
                )

            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        elif forward_mode.is_draft_extend():
            num_tokens_per_bs = self.speculative_num_steps + 1
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                bs * num_tokens_per_bs + 1,
                step=num_tokens_per_bs,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            custom_mask = None
            mask_indptr = None
            max_extend_len = num_tokens_per_bs
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
            )

        self.forward_metadata = ForwardMetadata(
            attn_logits,
            attn_lse,
            max_extend_len,
            num_kv_splits,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
            window_kv_indptr,
            window_kv_indices,
            window_num_kv_splits,
            window_kv_offsets,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInput],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
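        """Refresh the captured CUDA graph buffers in place before replaying the graph."""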
        # NOTE: encoder_lens expected to be zeros or None
        if forward_mode.is_decode_or_idle():
            # Update kv_indptr, kv_indices
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            num_kv_splits = self.cuda_graph_num_kv_splits
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                num_token = bs
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                    window_kv_indices = self.cuda_graph_window_kv_indices
                    _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices[:bs],
                        bs,
                        self.token_to_kv_pool_allocator,
                    )
                    self.get_num_kv_splits(
                        window_num_kv_splits[:num_token], window_kv_lens[:bs]
                    )

            else:
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
                num_token = spec_info.kv_indptr.shape[0] - 1
            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])

        elif forward_mode.is_target_verify():
            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
            bs = len(req_pool_indices)
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                window_kv_indices = self.cuda_graph_window_kv_indices
                window_kv_offsets = self.cuda_graph_window_kv_offsets
                _, _, window_kv_lens, window_kv_offsets[:bs] = (
                    update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices,
                        bs,
                        self.token_to_kv_pool_allocator,
                    )
                )
            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
        elif forward_mode.is_draft_extend():
            seq_lens = seq_lens[:bs]
            accept_lens = spec_info.accept_length[:bs]
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
            )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def get_verify_buffers_to_fill_after_draft(self):
        """
        Return buffers for verify attention kernels that need to be filled after the draft pass.

        Typically, these are tree mask and position buffers.
        """
        return [self.cuda_graph_custom_mask, None]

    def update_verify_buffers_to_fill_after_draft(
        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
    ):
        pass

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        sinks=None,
    ):
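        """Run prefill/extend attention for one layer and return the output tensor."""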
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)

        causal = True
        if layer.attn_type == AttentionType.ENCODER_ONLY:
            causal = False

        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            sliding_window_size = (
                layer.sliding_window_size
            )  # Needed for sliding window mask
            kv_indptr = self.forward_metadata.window_kv_indptr
            kv_indices = self.forward_metadata.window_kv_indices
            window_kv_offsets = self.forward_metadata.window_kv_offsets
        else:
            sliding_window_size = -1
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices
            window_kv_offsets = None

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            self.forward_metadata.qo_indptr,
            kv_indptr,
            kv_indices,
            self.forward_metadata.custom_mask,
            causal,
            self.forward_metadata.mask_indptr,
            self.forward_metadata.max_extend_len,
            layer.scaling,
            logit_cap=logits_soft_cap,
            sliding_window_size=sliding_window_size,
            sinks=sinks,
            window_kv_offsets=window_kv_offsets,
            xai_temperature_len=layer.xai_temperature_len,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        sinks=None,
    ):
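        """Run decode attention for one layer and return the output tensor."""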
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            kv_indptr = self.forward_metadata.window_kv_indptr
            kv_indices = self.forward_metadata.window_kv_indices
        else:
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            kv_indptr,
            kv_indices,
            self.forward_metadata.attn_logits,
            self.forward_metadata.attn_lse,
            self.forward_metadata.num_kv_splits,
            self.max_kv_splits,
            layer.scaling,
            logit_cap=logits_soft_cap,
            sinks=sinks,
            xai_temperature_len=layer.xai_temperature_len,
        )
        return o


class TritonMultiStepDraftBackend:
    """
    Wrap multiple triton attention backends as one for multiple consecutive
    draft decoding steps.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
        max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.attn_backends = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                TritonAttnBackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                )
            )
        self.max_context_len = self.attn_backends[0].max_context_len
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.device = model_runner.device
        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
        self.page_size = model_runner.server_args.page_size

    def common_template(
        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
    ):
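        """Generate the KV indices for all draft steps with a single kernel launch, then
        invoke `call_fn(i, forward_batch)` on each per-step attention backend."""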
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
            next_power_of_2(bs),
            self.page_size,
        )

        for i in range(self.speculative_num_steps):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        kv_indices = torch.empty(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int64,
            device=self.device,
        )

        def call_fn(i, forward_batch):
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
            dtype=torch.int64,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=None,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


@triton.jit
def get_num_kv_splits_triton(
    num_kv_splits_ptr,
    seq_lens_ptr,
    num_seq,
    num_group,
    num_head,
    num_kv_head,
    max_kv_splits,
    device_core_count,
    MAX_NUM_SEQ: tl.constexpr,
):
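    # Heuristically pick the number of KV splits per sequence from the min/max sequence
    # lengths and the device core count, then store the result for every token in each group.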
    # TODO: this method is tunable, we need more online serving data to tune it
    offs_seq = tl.arange(0, MAX_NUM_SEQ)
    mask_seq = offs_seq < num_seq

    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
    max_seq_len = tl.max(seq_lens)
    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
    min_seq_len = tl.min(seq_lens)
    if max_seq_len * 8 < min_seq_len * 10:
        min_seq_len = max_seq_len
    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)

    # NOTE: this is a hack to let num_kv_split grow gradually with seqlen
    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
    ext_device_core_count = tl.cast(
        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
    )
    block_h, num_kv_group = 16, num_head // num_kv_head
    if num_kv_group == 1:
        token_grid = num_seq * num_group * num_head
    else:
        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
        block_h = tl.minimum(block_h, num_kv_group)
        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
    max_kv_splits_2 = tl.minimum(
        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
    )
    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)

    num_kv_splits = tl.maximum(
        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
    )

    offs_token = offs_seq * num_group
    mask_token = offs_token < num_seq * num_group
    for i in range(0, num_group):
        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)


def update_sliding_window_buffer(
    window_kv_indptr,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    device,
    token_to_kv_pool_allocator=None,
):
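    """Build kv_indptr/kv_indices covering only the last `sliding_window_size` tokens of
    each sequence; returns the window indptr, indices, lengths, and start offsets."""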
    window_kv_lens = torch.minimum(
        seq_lens,
        torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]
    window_kv_indices = torch.empty(
        window_kv_indptr[-1], dtype=torch.int64, device=device
    )
    window_kv_start_idx = seq_lens - window_kv_lens
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )
    # full to swa index mapping
    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
        kv_last_index = window_kv_indptr[-1]
        window_kv_indices[:kv_last_index] = (
            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                window_kv_indices[:kv_last_index]
            )
        )
    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx


def update_sliding_window_buffer_cuda_graph(
    window_kv_indptr,
    window_kv_indices,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    token_to_kv_pool_allocator=None,
):
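    """Same as `update_sliding_window_buffer`, but writes into a pre-allocated
    `window_kv_indices` buffer so it is safe for CUDA graph capture and replay."""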
    window_kv_lens = torch.minimum(
        seq_lens,
        torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]
    window_kv_start_idx = seq_lens - window_kv_lens
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )
    # full to swa index mapping
    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
        kv_last_index = window_kv_indptr[-1]
        window_kv_indices[:kv_last_index] = (
            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                window_kv_indices[:kv_last_index]
            )
        )
    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx